diff --git a/mlir/include/air/Dialect/AIR/AIR.td b/mlir/include/air/Dialect/AIR/AIR.td
index 5e02c414c..a7ece1568 100644
--- a/mlir/include/air/Dialect/AIR/AIR.td
+++ b/mlir/include/air/Dialect/AIR/AIR.td
@@ -426,21 +426,21 @@ def air_ChannelOp : air_Op<"channel", [Symbol]>,
      int broadcastNum = 1;
      if (isBroadcast())
        for (auto bShape : getOperation()->getAttrOfType<ArrayAttr>("broadcast_shape")) {
-          auto attr = bShape.dyn_cast<IntegerAttr>().getInt();
+          auto attr = llvm::dyn_cast<IntegerAttr>(bShape).getInt();
          broadcastNum *= attr;
        }
      return broadcastNum;
    }
    int getBufferResources() {
      if(auto attr = getOperation()->getAttrOfType<IntegerAttr>("buffer_resources"))
-        return attr.dyn_cast<IntegerAttr>().getInt();
+        return llvm::dyn_cast<IntegerAttr>(attr).getInt();
      else
        return 1;
    }
    int getBundleSize() {
      int size = 1;
      for (auto i : getSize())
-        size *= i.dyn_cast<IntegerAttr>().getInt();
+        size *= llvm::dyn_cast<IntegerAttr>(i).getInt();
      return size;
    }
  }];
diff --git a/mlir/include/air/Dialect/AIR/AIROpBase.td b/mlir/include/air/Dialect/AIR/AIROpBase.td
index 0bea19583..44be67199 100644
--- a/mlir/include/air/Dialect/AIR/AIROpBase.td
+++ b/mlir/include/air/Dialect/AIR/AIROpBase.td
@@ -88,7 +88,7 @@ def MemorySpace: I32EnumAttr<"MemorySpace", "AIR Memory Space IDs",
 }
 
 def air_AsyncToken : DialectType<
-  air_Dialect, CPred<"$_self.isa<xilinx::air::AsyncTokenType>()">, "async token type">,
+  air_Dialect, CPred<"llvm::isa<xilinx::air::AsyncTokenType>($_self)">, "async token type">,
   BuildableType<"xilinx::air::AsyncTokenType::get($_builder.getContext())">;
 
 def air_AsyncOpInterface : OpInterface<"AsyncOpInterface"> {
diff --git a/mlir/include/air/Dialect/AIRRt/AIRRtBase.td b/mlir/include/air/Dialect/AIRRt/AIRRtBase.td
index b111f1d2b..8a3a3f8c7 100644
--- a/mlir/include/air/Dialect/AIRRt/AIRRtBase.td
+++ b/mlir/include/air/Dialect/AIRRt/AIRRtBase.td
@@ -23,7 +23,7 @@ can be lowered to a combination of standard and LLVM dialects.
 }
 
 def AIRRt_Event : DialectType<
-  AIRRt_Dialect, CPred<"$_self.isa<xilinx::airrt::EventType>()">, "event type">,
+  AIRRt_Dialect, CPred<"llvm::isa<xilinx::airrt::EventType>($_self)">, "event type">,
   BuildableType<"xilinx::airrt::EventType::get($_builder.getContext())">;
 
 #endif // #ifndef AIRRT_BASE
diff --git a/mlir/lib/CAPI/Dialects.cpp b/mlir/lib/CAPI/Dialects.cpp
index 5de2341a2..ece5d660a 100644
--- a/mlir/lib/CAPI/Dialects.cpp
+++ b/mlir/lib/CAPI/Dialects.cpp
@@ -18,7 +18,7 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(AIR, air, xilinx::air::airDialect)
 //===---------------------------------------------------------------------===//
 
 bool mlirTypeIsAIRAsyncTokenType(MlirType type) {
-  return unwrap(type).isa<xilinx::air::AsyncTokenType>();
+  return llvm::isa<xilinx::air::AsyncTokenType>(unwrap(type));
 }
 
 MlirType mlirAIRAsyncTokenTypeGet(MlirContext ctx) {
diff --git a/mlir/lib/Conversion/AIRLoweringPass.cpp b/mlir/lib/Conversion/AIRLoweringPass.cpp
index f9a5f2103..8fa2c1161 100644
--- a/mlir/lib/Conversion/AIRLoweringPass.cpp
+++ b/mlir/lib/Conversion/AIRLoweringPass.cpp
@@ -101,7 +101,7 @@ class AIRLaunchConversion : public ConversionPattern {
     rewriter.setInsertionPoint(scfPar);
     SmallVector<Value> deps;
     for (auto &o : operands)
-      if (o.getType().isa<airrt::EventType>())
+      if (llvm::isa<airrt::EventType>(o.getType()))
         deps.push_back(o);
     rewriter.replaceOpWithNewOp<airrt::WaitAllOp>(
         op, airrt::EventType::get(op->getContext()), deps);
@@ -170,7 +170,8 @@ class AIRSegmentConversion : public ConversionPattern {
         rewriter.clone(o, remap);
       } else if (auto chanOp = dyn_cast<air::ChannelInterface>(o)) {
         // clone L3 get/put
-        MemRefType memrefTy = chanOp.getMemref().getType().cast<MemRefType>();
+        MemRefType memrefTy =
+            llvm::cast<MemRefType>(chanOp.getMemref().getType());
         if (memrefTy.getMemorySpaceAsInt() == (int)air::MemorySpace::L3) {
           rewriter.clone(o, remap);
           continue;
@@ -190,7 +191,7 @@ class AIRSegmentConversion : public ConversionPattern {
 
     SmallVector<Value> deps;
     for (auto &o : operands)
-      if (o.getType().isa<airrt::EventType>())
+      if (llvm::isa<airrt::EventType>(o.getType()))
         deps.push_back(o);
     if (op->getNumResults()) {
       rewriter.setInsertionPoint(op);
@@ -231,7 +232,7 @@ class AIRHerdConversion : public ConversionPattern {
 
     SmallVector<Value> deps;
     for (auto &o : operands)
-      if (o.getType().isa<airrt::EventType>())
+      if (llvm::isa<airrt::EventType>(o.getType()))
         deps.push_back(o);
     if (op->getNumResults()) {
       auto w = rewriter.create<airrt::WaitAllOp>(
@@ -320,7 +321,7 @@ class AIRPipelineGetConversion : public ConversionPattern {
     auto getOp = cast<air::PipelineGetOp>(op);
     SmallVector<Value> gets;
     for (auto r : getOp.getResults()) {
-      if (auto ty = r.getType().dyn_cast<RankedTensorType>())
+      if (auto ty = llvm::dyn_cast<RankedTensorType>(r.getType()))
         gets.push_back(rewriter.create<bufferization::AllocTensorOp>(
             op->getLoc(), ty, ValueRange{}));
       else
@@ -361,14 +362,14 @@ class AIRDmaMemcpyNdToAIRRtConversion
 
     SmallVector<Value> deps;
     for (auto o : adaptor.getOperands())
-      if (o.getType().isa<airrt::EventType>())
+      if (llvm::isa<airrt::EventType>(o.getType()))
         deps.push_back(o);
     if (deps.size())
       rewriter.create<airrt::WaitAllOp>(
           op->getLoc(), airrt::EventType::get(op->getContext()), deps);
 
-    MemRefType src = op.getSrcMemref().getType().cast<MemRefType>();
-    MemRefType dst = op.getDstMemref().getType().cast<MemRefType>();
+    MemRefType src = llvm::cast<MemRefType>(op.getSrcMemref().getType());
+    MemRefType dst = llvm::cast<MemRefType>(op.getDstMemref().getType());
     bool isFromTile = false;
     bool isFullMemcpy = false;
     if (src.getMemorySpaceAsInt() == (int)air::MemorySpace::L1 &&
@@ -493,7 +494,8 @@ AIRChannelInterfaceToAIRRtConversionImpl(OpBuilder builder,
   auto loc = thisOp->getLoc();
   auto ctx = thisOp->getContext();
 
-  MemRefType thisMemrefType = thisOp.getMemref().getType().cast<MemRefType>();
+  MemRefType thisMemrefType =
+      llvm::cast<MemRefType>(thisOp.getMemref().getType());
 
   bool thisOpIsInShim =
       thisMemrefType.getMemorySpaceAsInt() == (int)xilinx::air::MemorySpace::L3;
@@ -523,9 +525,9 @@ AIRChannelInterfaceToAIRRtConversionImpl(OpBuilder builder,
     // Broadcast channel control loop
     assert(theOtherOp->hasAttr("tile"));
     ArrayAttr tiles = theOtherOp->getAttrOfType<ArrayAttr>("tile");
-    auto tile_dict = tiles[0].cast<DictionaryAttr>();
-    auto row = tile_dict.get("row").cast<IntegerAttr>().getInt();
-    auto col = tile_dict.get("col").cast<IntegerAttr>().getInt();
+    auto tile_dict = llvm::cast<DictionaryAttr>(tiles[0]);
+    auto row = llvm::cast<IntegerAttr>(tile_dict.get("row")).getInt();
+    auto col = llvm::cast<IntegerAttr>(tile_dict.get("col")).getInt();
     opers.push_back(builder.create<arith::ConstantOp>(
         loc, i64Ty, IntegerAttr::get(i64Ty, col)));
     opers.push_back(builder.create<arith::ConstantOp>(
@@ -627,7 +629,7 @@ class AIRChannelPutToAIRRtConversion
     // Resolve channel op's dependency list
     SmallVector<Value> deps;
     for (auto o : adaptor.getOperands())
-      if (o.getType().isa<airrt::EventType>())
+      if (llvm::isa<airrt::EventType>(o.getType()))
         deps.push_back(o);
     if (deps.size())
       rewriter.replaceOpWithNewOp<airrt::WaitAllOp>(
@@ -672,7 +674,7 @@ class AIRChannelGetToAIRRtConversion
     // Resolve channel op's dependency list
     SmallVector<Value> deps;
     for (auto o : adaptor.getOperands())
-      if (o.getType().isa<airrt::EventType>())
+      if (llvm::isa<airrt::EventType>(o.getType()))
         deps.push_back(o);
     if (deps.size())
       rewriter.replaceOpWithNewOp<airrt::WaitAllOp>(
@@ -712,7 +714,7 @@ class L2DeallocToAIRRtConversion : public ConversionPattern {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto dealloc = cast<memref::DeallocOp>(op);
-    auto type = dealloc.getMemref().getType().cast<MemRefType>();
+    auto type = llvm::cast<MemRefType>(dealloc.getMemref().getType());
     if (type.getMemorySpaceAsInt() == (int)air::MemorySpace::L2) {
       rewriter.replaceOpWithNewOp<airrt::DeallocOp>(op, SmallVector<Type>{},
                                                     op->getOperands());
@@ -772,7 +774,7 @@ class ScfYieldOpConversion : public OpConversionPattern<scf::YieldOp> {
     SmallVector<Value> operands{adaptor.getOperands()};
     SmallVector<Type> retTys;
     for (auto t : op->getResultTypes()) {
-      if (t.isa<air::AsyncTokenType>()) {
+      if (llvm::isa<air::AsyncTokenType>(t)) {
        retTys.push_back(airrt::EventType::get(op->getContext()));
      } else {
        retTys.push_back(t);
@@ -813,7 +815,7 @@ class ScfReduceOpConversion : public OpConversionPattern<scf::ReduceOp> {
       SmallVector<Value> opers;
       for (int i = 0, e = o.getNumOperands(); i < e; i++) {
         auto oper = remap.lookupOrDefault(o.getOperand(i));
-        if (oper.getType().isa<air::AsyncTokenType>()) {
+        if (llvm::isa<air::AsyncTokenType>(oper.getType())) {
           auto ty = airrt::EventType::get(o.getContext());
           auto cast = rewriter.create<UnrealizedConversionCastOp>(
               op->getLoc(), ty, oper);
@@ -838,7 +840,7 @@ class ScfReduceReturnOpConversion
     SmallVector<Value> operands{adaptor.getOperands()};
     SmallVector<Type> retTys;
     for (auto t : op->getResultTypes()) {
-      if (t.isa<air::AsyncTokenType>()) {
+      if (llvm::isa<air::AsyncTokenType>(t)) {
        retTys.push_back(airrt::EventType::get(op->getContext()));
      } else {
        retTys.push_back(t);
@@ -859,7 +861,7 @@ class ScfIfOpConversion : public OpConversionPattern<scf::IfOp> {
 
     SmallVector<Type> retTys;
     for (auto t : op->getResultTypes()) {
-      if (t.isa<air::AsyncTokenType>()) {
+      if (llvm::isa<air::AsyncTokenType>(t)) {
        retTys.push_back(airrt::EventType::get(op->getContext()));
      } else {
        retTys.push_back(t);
@@ -917,7 +919,7 @@ class ScfForOpConversion : public OpConversionPattern<scf::ForOp> {
       SmallVector<Value> opers;
       for (int i = 0, e = o.getNumOperands(); i < e; i++) {
         auto oper = remap.lookupOrDefault(o.getOperand(i));
-        if (oper.getType().isa<air::AsyncTokenType>()) {
+        if (llvm::isa<air::AsyncTokenType>(oper.getType())) {
           auto ty = airrt::EventType::get(o.getContext());
           auto cast = rewriter.create<UnrealizedConversionCastOp>(
               op->getLoc(), ty, oper);
@@ -937,7 +939,7 @@ class ScfForOpConversion : public OpConversionPattern<scf::ForOp> {
     rewriter.setInsertionPointAfter(newOp);
     SmallVector<Value> newResults;
     for (auto res : newOp->getResults()) {
-      if (res.getType().isa<airrt::EventType>()) {
+      if (llvm::isa<airrt::EventType>(res.getType())) {
         auto ty = air::AsyncTokenType::get(op->getContext());
         auto cast =
             rewriter.create<UnrealizedConversionCastOp>(op->getLoc(), ty, res);
@@ -958,7 +960,7 @@ class ScfParOpConversion : public OpConversionPattern<scf::ParallelOp> {
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<Value> newInitVals;
     for (auto initVal : adaptor.getInitVals()) {
-      if (initVal.getType().isa<air::AsyncTokenType>()) {
+      if (llvm::isa<air::AsyncTokenType>(initVal.getType())) {
         auto cast = rewriter.create<UnrealizedConversionCastOp>(
             op->getLoc(), airrt::EventType::get(op->getContext()), initVal);
         newInitVals.push_back(cast.getResult(0));
@@ -996,7 +998,7 @@ class ScfParOpConversion : public OpConversionPattern<scf::ParallelOp> {
     rewriter.setInsertionPointAfter(newOp);
     SmallVector<Value> newResults;
     for (auto res : newOp->getResults()) {
-      if (res.getType().isa<airrt::EventType>()) {
+      if (llvm::isa<airrt::EventType>(res.getType())) {
         auto ty = air::AsyncTokenType::get(op->getContext());
         auto cast =
             rewriter.create<UnrealizedConversionCastOp>(op->getLoc(), ty, res);
@@ -1086,7 +1088,7 @@ class AIRLoweringPass : public air::impl::AIRLoweringBase<AIRLoweringPass> {
     TypeConverter converter;
     converter.addConversion([&](Type type) -> std::optional<Type> {
       // convert !air.async.token to !airrt.event
-      if (auto t = type.dyn_cast<air::AsyncTokenType>())
+      if (auto t = llvm::dyn_cast<air::AsyncTokenType>(type))
         return airrt::EventType::get(context);
       else
         return type;
@@ -1147,14 +1149,13 @@ class AIRLoweringPass : public air::impl::AIRLoweringBase<AIRLoweringPass> {
     });
 
     target.addDynamicallyLegalOp<memref::DeallocOp>([&](memref::DeallocOp op) {
-      return (
-          op.getMemref().getType().cast<MemRefType>().getMemorySpaceAsInt() !=
-          (int)air::MemorySpace::L2);
+      return (llvm::cast<MemRefType>(op.getMemref().getType())
+                  .getMemorySpaceAsInt() != (int)air::MemorySpace::L2);
     });
 
     target.addDynamicallyLegalOp<scf::ForOp>([&](scf::ForOp op) {
       for (auto o : op.getRegionIterArgs()) {
-        if (o.getType().isa<air::AsyncTokenType>())
+        if (llvm::isa<air::AsyncTokenType>(o.getType()))
           return false;
       }
       return true;
@@ -1162,7 +1163,7 @@ class AIRLoweringPass : public air::impl::AIRLoweringBase<AIRLoweringPass> {
 
     target.addDynamicallyLegalOp<scf::ParallelOp>([&](scf::ParallelOp op) {
       for (auto v : op.getResults()) {
-        if (v.getType().isa<air::AsyncTokenType>())
+        if (llvm::isa<air::AsyncTokenType>(v.getType()))
           return false;
       }
       return true;
@@ -1170,7 +1171,7 @@ class AIRLoweringPass : public air::impl::AIRLoweringBase<AIRLoweringPass> {
 
     target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
       for (auto v : op.getResults()) {
-        if (v.getType().isa<air::AsyncTokenType>())
+        if (llvm::isa<air::AsyncTokenType>(v.getType()))
           return false;
       }
       return true;
@@ -1178,14 +1179,14 @@ class AIRLoweringPass : public air::impl::AIRLoweringBase<AIRLoweringPass> {
 
     target.addDynamicallyLegalOp<scf::ReduceOp>([&](scf::ReduceOp op) {
       for (auto o : op.getOperands())
-        if (o.getType().isa<air::AsyncTokenType>())
+        if (llvm::isa<air::AsyncTokenType>(o.getType()))
          return false;
      return true;
    });
 
     target.addDynamicallyLegalOp<scf::ReduceReturnOp>(
         [&](scf::ReduceReturnOp op) {
-          if (op.getResult().getType().isa<air::AsyncTokenType>())
+          if (llvm::isa<air::AsyncTokenType>(op.getResult().getType()))
            return false;
          else
            return true;
@@ -1193,7 +1194,7 @@ class AIRLoweringPass : public air::impl::AIRLoweringBase<AIRLoweringPass> {
 
     target.addDynamicallyLegalOp<scf::IfOp>([&](scf::IfOp op) {
       for (auto v : op.getResults()) {
-        if (v.getType().isa<air::AsyncTokenType>())
+        if (llvm::isa<air::AsyncTokenType>(v.getType()))
           return false;
       }
       return true;
diff --git a/mlir/lib/Conversion/AIRPipeline.cpp b/mlir/lib/Conversion/AIRPipeline.cpp
index ebf93a553..8e8484174 100644
--- a/mlir/lib/Conversion/AIRPipeline.cpp
+++ b/mlir/lib/Conversion/AIRPipeline.cpp
@@ -64,7 +64,7 @@ LogicalResult AIRPipeStageConversion::matchAndRewrite(
     // For each output of the pipeline stage, create a buffer + store
     SmallVector<Value> bufs;
     for (auto o : yield.getOperands()) {
-      if (RankedTensorType tt = o.getType().dyn_cast<RankedTensorType>()) {
+      if (RankedTensorType tt = llvm::dyn_cast<RankedTensorType>(o.getType())) {
         auto memrefTy = MemRefType::get(tt.getShape(), tt.getElementType());
         rewriter.setInsertionPoint(aif);
         auto buf = rewriter.create<memref::AllocOp>(op->getLoc(), memrefTy);
@@ -83,7 +83,7 @@ LogicalResult AIRPipeStageConversion::matchAndRewrite(
     SmallVector<Value> bufs;
     rewriter.setInsertionPoint(aif);
     for (auto o : yield.getOperands()) {
-      if (RankedTensorType tt = o.getType().dyn_cast<RankedTensorType>()) {
+      if (RankedTensorType tt = llvm::dyn_cast<RankedTensorType>(o.getType())) {
         rewriter.setInsertionPoint(&yield);
         auto idValPlus =
             rewriter.create<arith::ConstantIndexOp>(op->getLoc(), id + 1);
diff --git a/mlir/lib/Conversion/AIRRtToLLVMPass.cpp b/mlir/lib/Conversion/AIRRtToLLVMPass.cpp
index 48c95b0fe..940a84cee 100644
--- a/mlir/lib/Conversion/AIRRtToLLVMPass.cpp
+++ b/mlir/lib/Conversion/AIRRtToLLVMPass.cpp
@@ -444,14 +444,16 @@ class ModuleMetadataToLLVMConversion
           herd_meta->getAttrOfType<ArrayAttr>("dma_allocations");
       assert(shim_attr);
       for (auto &shim_alloc : shim_attr) {
-        auto shim_alloc_dict = shim_alloc.cast<DictionaryAttr>();
-        auto id = shim_alloc_dict.get("id").cast<IntegerAttr>().getInt();
-        auto row = shim_alloc_dict.get("row").cast<IntegerAttr>().getInt();
-        auto col = shim_alloc_dict.get("col").cast<IntegerAttr>().getInt();
+        auto shim_alloc_dict = llvm::cast<DictionaryAttr>(shim_alloc);
+        auto id = llvm::cast<IntegerAttr>(shim_alloc_dict.get("id")).getInt();
+        auto row =
+            llvm::cast<IntegerAttr>(shim_alloc_dict.get("row")).getInt();
+        auto col =
+            llvm::cast<IntegerAttr>(shim_alloc_dict.get("col")).getInt();
         auto channel =
-            shim_alloc_dict.get("channel").cast<IntegerAttr>().getInt();
+            llvm::cast<IntegerAttr>(shim_alloc_dict.get("channel")).getInt();
         auto location =
-            shim_alloc_dict.get("location").cast<IntegerAttr>().getInt();
+            llvm::cast<IntegerAttr>(shim_alloc_dict.get("location")).getInt();
         cols[id - 1][row][col] = location;
         chans[id - 1][row][col] = channel;
       }
@@ -579,7 +581,7 @@ LogicalResult lowerDmaNdMemcpy(Operation *op, PatternRewriter &rewriter,
     operands.push_back(o);
   }
 
-  MemRefType memrefTy = tys[4].cast<MemRefType>();
+  MemRefType memrefTy = llvm::cast<MemRefType>(tys[4]);
   tys[4] = MemRefType::get(
       std::vector<int64_t>(memrefTy.getRank(), ShapedType::kDynamic),
       memrefTy.getElementType(), memrefTy.getLayout(),
@@ -635,8 +637,8 @@ LogicalResult lowerNdMemcpy(Operation *op, PatternRewriter &rewriter,
     operands.push_back(nullV);
   }
 
-  MemRefType dstMemRefTy = dmaOp.getDst().getType().cast<MemRefType>();
-  MemRefType srcMemRefTy = dmaOp.getSrc().getType().cast<MemRefType>();
+  MemRefType dstMemRefTy = llvm::cast<MemRefType>(dmaOp.getDst().getType());
+  MemRefType srcMemRefTy = llvm::cast<MemRefType>(dmaOp.getSrc().getType());
 
   for (auto o : op->getOperands())
     operands.push_back(o);
@@ -738,7 +740,7 @@ class L1DeallocOpConversion : public OpConversionPattern<memref::DeallocOp> {
   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto memrefTy = op.getMemref().getType().cast<MemRefType>();
+    auto memrefTy = llvm::cast<MemRefType>(op.getMemref().getType());
     if (memrefTy.getMemorySpaceAsInt() != (int)xilinx::air::MemorySpace::L1)
       return failure();
 
@@ -757,7 +759,7 @@ class L1AffineStoreOpConversion
   matchAndRewrite(affine::AffineStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto memrefTy = op.getMemref().getType().cast<MemRefType>();
+    auto memrefTy = llvm::cast<MemRefType>(op.getMemref().getType());
     if (memrefTy.getMemorySpaceAsInt() != (int)xilinx::air::MemorySpace::L1)
       return failure();
 
@@ -774,7 +776,7 @@ class L1MemRefLoadOpConversion : public OpConversionPattern<memref::LoadOp> {
   matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto memrefTy = op.getMemref().getType().cast<MemRefType>();
+    auto memrefTy = llvm::cast<MemRefType>(op.getMemref().getType());
     if (memrefTy.getMemorySpaceAsInt() != (int)xilinx::air::MemorySpace::L1)
       return failure();
 
@@ -794,7 +796,7 @@ class L1MemRefStoreOpConversion : public OpConversionPattern<memref::StoreOp> {
   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto memrefTy = op.getMemref().getType().cast<MemRefType>();
+    auto memrefTy = llvm::cast<MemRefType>(op.getMemref().getType());
     if (memrefTy.getMemorySpaceAsInt() != (int)xilinx::air::MemorySpace::L1)
       return failure();
 
@@ -812,7 +814,7 @@ class L1AffineLoadOpConversion
   matchAndRewrite(affine::AffineLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto memrefTy = op.getMemref().getType().cast<MemRefType>();
+    auto memrefTy = llvm::cast<MemRefType>(op.getMemref().getType());
     if (memrefTy.getMemorySpaceAsInt() != (int)xilinx::air::MemorySpace::L1)
       return failure();
 
@@ -837,7 +839,7 @@ class L2AllocOpConversion : public OpRewritePattern<memref::AllocOp> {
 
     auto ctx = op->getContext();
 
-    auto memrefTy = op.getType().cast<MemRefType>();
+    auto memrefTy = llvm::cast<MemRefType>(op.getType());
     if (memrefTy.getMemorySpaceAsInt() != (int)xilinx::air::MemorySpace::L2)
       return failure();
 
@@ -889,7 +891,7 @@ class L2DeallocOpConversion
     SmallVector<Type> retTys;
     auto ctx = op->getContext();
 
-    auto memrefTy = op.getMemref().getType().cast<MemRefType>();
+    auto memrefTy = llvm::cast<MemRefType>(op.getMemref().getType());
     if (memrefTy.getMemorySpaceAsInt() != (int)xilinx::air::MemorySpace::L2)
       return failure();
 
@@ -1152,12 +1154,12 @@ class AIRRtToLLVM : public impl::AIRRtToLLVMBase<AIRRtToLLVM> {
 
     converter.addConversion([&](Type type) -> std::optional<Type> {
       // convert L1 memrefs to L3
-      if (auto memref = type.dyn_cast<MemRefType>())
+      if (auto memref = llvm::dyn_cast<MemRefType>(type))
         if (memref.getMemorySpaceAsInt() == (int)xilinx::air::MemorySpace::L1)
           return mlir::MemRefType::get(memref.getShape(),
                                        memref.getElementType(),
                                        memref.getLayout(), 0);
-      if (auto t = type.dyn_cast<airrt::EventType>())
+      if (auto t = llvm::dyn_cast<airrt::EventType>(type))
         return LLVM::LLVMPointerType::get(context);
       return std::optional<Type>(type);
     });
@@ -1196,37 +1198,32 @@ class AIRRtToLLVM : public impl::AIRRtToLLVMBase<AIRRtToLLVM> {
     });
 
     target.addDynamicallyLegalOp<memref::DeallocOp>([&](memref::DeallocOp op) {
-      return (
-          op.getMemref().getType().cast<MemRefType>().getMemorySpaceAsInt() ==
-          0);
+      return (llvm::cast<MemRefType>(op.getMemref().getType())
+                  .getMemorySpaceAsInt() == 0);
     });
 
     target.addDynamicallyLegalOp<affine::AffineStoreOp>(
         [&](affine::AffineStoreOp op) {
-          return (op.getMemref()
-                      .getType()
-                      .cast<MemRefType>()
+          return (llvm::cast<MemRefType>(op.getMemref().getType())
                       .getMemorySpaceAsInt() !=
                   (int)xilinx::air::MemorySpace::L1);
         });
 
-    target.addDynamicallyLegalOp<affine::AffineLoadOp>([&](affine::AffineLoadOp
-                                                               op) {
-      return (
-          op.getMemref().getType().cast<MemRefType>().getMemorySpaceAsInt() !=
-          (int)xilinx::air::MemorySpace::L1);
-    });
+    target.addDynamicallyLegalOp<affine::AffineLoadOp>(
+        [&](affine::AffineLoadOp op) {
+          return (llvm::cast<MemRefType>(op.getMemref().getType())
+                      .getMemorySpaceAsInt() !=
+                  (int)xilinx::air::MemorySpace::L1);
+        });
 
     target.addDynamicallyLegalOp<memref::StoreOp>([&](memref::StoreOp op) {
-      return (
-          op.getMemref().getType().cast<MemRefType>().getMemorySpaceAsInt() !=
-          (int)xilinx::air::MemorySpace::L1);
+      return (llvm::cast<MemRefType>(op.getMemref().getType())
+                  .getMemorySpaceAsInt() != (int)xilinx::air::MemorySpace::L1);
     });
 
     target.addDynamicallyLegalOp<memref::LoadOp>([&](memref::LoadOp op) {
-      return (
-          op.getMemref().getType().cast<MemRefType>().getMemorySpaceAsInt() !=
-          (int)xilinx::air::MemorySpace::L1);
+      return (llvm::cast<MemRefType>(op.getMemref().getType())
+                  .getMemorySpaceAsInt() != (int)xilinx::air::MemorySpace::L1);
     });
 
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
@@ -1235,12 +1232,12 @@ class AIRRtToLLVM : public impl::AIRRtToLLVMBase<AIRRtToLLVM> {
 
     target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
       for (auto t : op.getOperandTypes()) {
-        if (auto mty = t.dyn_cast<MemRefType>())
+        if (auto mty = llvm::dyn_cast<MemRefType>(t))
           if (mty.getMemorySpaceAsInt() == (int)xilinx::air::MemorySpace::L1)
             return false;
       }
       for (auto t : op.getResultTypes()) {
-        if (auto mty = t.dyn_cast<MemRefType>())
+        if (auto mty = llvm::dyn_cast<MemRefType>(t))
           if (mty.getMemorySpaceAsInt() == (int)xilinx::air::MemorySpace::L1)
             return false;
       }
@@ -1249,7 +1246,7 @@ class AIRRtToLLVM : public impl::AIRRtToLLVMBase<AIRRtToLLVM> {
 
     target.addDynamicallyLegalOp<scf::ForOp>([&](scf::ForOp op) {
       for (auto o : op.getRegionIterArgs()) {
-        if (o.getType().isa<airrt::EventType>())
+        if (llvm::isa<airrt::EventType>(o.getType()))
           return false;
       }
       return true;
@@ -1257,7 +1254,7 @@ class AIRRtToLLVM : public impl::AIRRtToLLVMBase<AIRRtToLLVM> {
 
     target.addDynamicallyLegalOp<scf::ParallelOp>([&](scf::ParallelOp op) {
       for (auto o : op.getInitVals()) {
-        if (o.getType().isa<airrt::EventType>())
+        if (llvm::isa<airrt::EventType>(o.getType()))
           return false;
       }
       return true;
@@ -1265,7 +1262,7 @@ class AIRRtToLLVM : public impl::AIRRtToLLVMBase<AIRRtToLLVM> {
 
     target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
       for (auto v : op.getOperands()) {
-        if (v.getType().isa<airrt::EventType>())
+        if (llvm::isa<airrt::EventType>(v.getType()))
           return false;
       }
       return true;
@@ -1273,7 +1270,7 @@ class AIRRtToLLVM : public impl::AIRRtToLLVMBase<AIRRtToLLVM> {
 
     target.addDynamicallyLegalOp<scf::ReduceOp>([&](scf::ReduceOp op) {
       for (auto oper : op.getOperands()) {
-        if (oper.getType().isa<airrt::EventType>())
+        if (llvm::isa<airrt::EventType>(oper.getType()))
          return false;
      }
      return true;
@@ -1281,7 +1278,7 @@ class AIRRtToLLVM : public impl::AIRRtToLLVMBase<AIRRtToLLVM> {
 
     target.addDynamicallyLegalOp<scf::ReduceReturnOp>(
         [&](scf::ReduceReturnOp op) {
-          if (op.getResult().getType().isa<airrt::EventType>())
+          if (llvm::isa<airrt::EventType>(op.getResult().getType()))
            return false;
          else
            return true;
@@ -1289,7 +1286,7 @@ class AIRRtToLLVM : public impl::AIRRtToLLVMBase<AIRRtToLLVM> {
 
     target.addDynamicallyLegalOp<scf::IfOp>([&](scf::IfOp op) {
       for (auto v : op.getResults()) {
-        if (v.getType().isa<airrt::EventType>())
+        if (llvm::isa<airrt::EventType>(v.getType()))
           return false;
       }
       return true;
diff --git a/mlir/lib/Conversion/AIRRtToNpuPass.cpp b/mlir/lib/Conversion/AIRRtToNpuPass.cpp
index 384a43dd9..2f985e78e 100644
--- a/mlir/lib/Conversion/AIRRtToNpuPass.cpp
+++ b/mlir/lib/Conversion/AIRRtToNpuPass.cpp
@@ -262,7 +262,7 @@ class L1AffineStoreOpConversion
   matchAndRewrite(affine::AffineStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto memrefTy = op.getMemref().getType().cast<MemRefType>();
+    auto memrefTy = llvm::cast<MemRefType>(op.getMemref().getType());
     if (memrefTy.getMemorySpaceAsInt() != (int)xilinx::air::MemorySpace::L1)
       return failure();
 
@@ -279,7 +279,7 @@ class L1MemRefStoreOpConversion : public OpConversionPattern<memref::StoreOp> {
   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto memrefTy = op.getMemref().getType().cast<MemRefType>();
+    auto memrefTy = llvm::cast<MemRefType>(op.getMemref().getType());
     if (memrefTy.getMemorySpaceAsInt() != (int)xilinx::air::MemorySpace::L1)
       return failure();
 
@@ -587,10 +587,9 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
           builder.create<arith::ConstantOp>(
              loc, builder.getI64Type(),
              IntegerAttr::get(builder.getI64Type(), outer_wrap)));
-      auto new_const_stride =
-          (const_stride * inner_wrap) %
-          air::getTensorVolume(
-              memcpy_op.getMemref().getType().cast<MemRefType>());
+      auto new_const_stride = (const_stride * inner_wrap) %
+                              air::getTensorVolume(llvm::cast<MemRefType>(
+                                  memcpy_op.getMemref().getType()));
       strides.insert(
           strides.begin() + i,
           builder.create<arith::ConstantOp>(
@@ -915,18 +914,15 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
         [&](affine::AffineStoreOp op) {
           if (op->getParentOfType<AIE::CoreOp>())
             return true;
-          return (op.getMemref()
-                      .getType()
-                      .cast<MemRefType>()
+          return (llvm::cast<MemRefType>(op.getMemref().getType())
                      .getMemorySpaceAsInt() !=
                  (int)xilinx::air::MemorySpace::L1);
        });
     target.addDynamicallyLegalOp<memref::StoreOp>([&](memref::StoreOp op) {
       if (op->getParentOfType<AIE::CoreOp>())
         return true;
-      return (
-          op.getMemref().getType().cast<MemRefType>().getMemorySpaceAsInt() !=
-          (int)xilinx::air::MemorySpace::L1);
+      return (llvm::cast<MemRefType>(op.getMemref().getType())
+                  .getMemorySpaceAsInt() != (int)xilinx::air::MemorySpace::L1);
     });
     target.addDynamicallyLegalOp<memref::CopyOp>([&](memref::CopyOp op) {
       auto f = op->getParentOfType<func::FuncOp>();
diff --git a/mlir/lib/Conversion/AIRToAIEPass.cpp b/mlir/lib/Conversion/AIRToAIEPass.cpp
index 865dd4fb0..5764536eb 100644
--- a/mlir/lib/Conversion/AIRToAIEPass.cpp
+++ b/mlir/lib/Conversion/AIRToAIEPass.cpp
@@ -60,7 +60,7 @@ struct AIRToAIEConversionOptions {
 
 // get memcpy operation volumn (elements) as int
 int getMemcpySizesAsInt(Value memref, SmallVector<Value> sizes) {
-  MemRefType memTy = memref.getType().cast<MemRefType>();
+  MemRefType memTy = llvm::cast<MemRefType>(memref.getType());
   if (sizes.empty())
     return getTensorVolume(memTy);
   else {
@@ -314,7 +314,7 @@ void outlineAIECores(OpBuilder &builder, AIE::DeviceOp aie_device,
 
   for (unsigned i = 0; i < h.getNumKernelOperands(); i++) {
     auto a = h.getKernelArgument(i);
-    auto memrefTy = a.getType().dyn_cast<MemRefType>();
+    auto memrefTy = llvm::dyn_cast<MemRefType>(a.getType());
     if (!memrefTy)
       continue;
 
@@ -509,7 +509,7 @@ void createAIEModulesAndOutlineCores(
       auto oper = h.getKernelOperand(i);
       if (!oper.getDefiningOp())
         continue;
-      auto memrefTy = oper.getType().dyn_cast<MemRefType>();
+      auto memrefTy = llvm::dyn_cast<MemRefType>(oper.getType());
       if (!memrefTy)
         continue;
       if (memrefTy.getMemorySpaceAsInt() != (int)air::MemorySpace::L1)
@@ -706,7 +706,7 @@ struct LowerScfTokenPattern : public OpRewritePattern<scf::ForOp> {
         Value v =
             fop.getOperand(block_arg.getArgNumber() - fop.getNumInductionVars() +
                            fop.getNumControlOperands());
-        if (v.getType().isa<air::AsyncTokenType>()) {
+        if (llvm::isa<air::AsyncTokenType>(v.getType())) {
          block_arg.replaceAllUsesWith(v);
          iter_args_idx.set(block_arg.getArgNumber());
        } else {
@@ -739,7 +739,7 @@ struct LowerScfTokenPattern : public OpRewritePattern<scf::ForOp> {
     // use the new for op's results
     int idx = 0;
     for (auto r : fop.getResults()) {
-      if (r.getType().isa<air::AsyncTokenType>())
+      if (llvm::isa<air::AsyncTokenType>(r.getType()))
         r.replaceAllUsesWith(
             rewriter
                 .create<air::WaitAllOp>(
@@ -758,7 +758,7 @@ struct LowerScfTokenPattern : public OpRewritePattern<scf::ForOp> {
     SmallVector<Value> yield_operands;
     SmallVector<Value> token_operands;
     for (auto o : yield->getOperands()) {
-      if (o.getType().isa<air::AsyncTokenType>())
+      if (llvm::isa<air::AsyncTokenType>(o.getType()))
         token_operands.push_back(o);
       else
         yield_operands.push_back(o);
@@ -817,7 +817,8 @@ void lowerScfAirTokens(AIE::DeviceOp m) {
//     auto o = std::get<0>(p); // operand of put
//     auto r = std::get<1>(p); // result of get
//     // for each ranked tensor put (yielded) by the tile
-//     if (RankedTensorType tt = o.getType().dyn_cast<RankedTensorType>()) {
+//     if (RankedTensorType tt =
+//             llvm::dyn_cast<RankedTensorType>(o.getType())) {
//       auto memrefTy = MemRefType::get(tt.getShape(), tt.getElementType(),
//                                       {},
//                                       (int)air::MemorySpace::L1);
@@ -889,7 +890,7 @@ void lowerScfAirTokens(AIE::DeviceOp m) {
//       return failure();
 
//     MemRefType memrefTy = nullptr;
-//     memrefTy = cast.getType().cast<MemRefType>();
+//     memrefTy = llvm::cast<MemRefType>(cast.getType());
 
//     if (memrefTy.getMemorySpaceAsInt() != (int)air::MemorySpace::L1)
//       return failure();
 
@@ -1072,8 +1073,8 @@ void L2MemrefToMemTileMap(
     std::map<memref::AllocOp, AIE::TileOp> &memrefToMemTileMap) {
   std::vector<memref::AllocOp> allocs;
   m.walk([&](memref::AllocOp alloc) {
-    if (alloc.getMemref().getType().cast<MemRefType>().getMemorySpaceAsInt() ==
-        (int)air::MemorySpace::L2) {
+    if (llvm::cast<MemRefType>(alloc.getMemref().getType())
+            .getMemorySpaceAsInt() == (int)air::MemorySpace::L2) {
      allocs.push_back(alloc);
    }
  });
@@ -1111,7 +1112,7 @@ void L2MemrefToMemTileMap(
   int memtile_id = 0;
   for (auto &bucket : memref_buckets) {
     for (auto bucket_elem : bucket) {
-      MemRefType ty = bucket_elem.getMemref().getType().cast<MemRefType>();
+      MemRefType ty = llvm::cast<MemRefType>(bucket_elem.getMemref().getType());
       auto memref_vol = getElementSizeInBytes(ty) * getTensorVolume(ty);
       memtileToSizeMap[memtiles[memtile_id]] -= memref_vol;
       memrefToMemTileMap[bucket_elem] = memtiles[memtile_id];
@@ -1199,7 +1200,7 @@ struct LowerAIRChannelsPattern : public OpRewritePattern<air::ChannelOp> {
 
       // check if this put is linked to a get from another channel
       MemRefType memref =
-          channelPuts[0].getMemref().getType().cast<MemRefType>();
+          llvm::cast<MemRefType>(channelPuts[0].getMemref().getType());
       int mem_space = memref.getMemorySpaceAsInt();
       if (mem_space == (int)air::MemorySpace::L2) {
         if (linksToComplete.find(channelPuts[0].getOperation()) !=
@@ -1239,7 +1240,7 @@ struct LowerAIRChannelsPattern : public OpRewritePattern<air::ChannelOp> {
         consumers.push_back(consumerTile);
 
       // check if this get is linked to a put from another channel
-      MemRefType memref = get.getMemref().getType().cast<MemRefType>();
+      MemRefType memref = llvm::cast<MemRefType>(get.getMemref().getType());
       int mem_space = memref.getMemorySpaceAsInt();
       if (mem_space == (int)air::MemorySpace::L2) {
         if (linksToComplete.find(get.getOperation()) != linksToComplete.end()) {
@@ -1349,7 +1350,7 @@ struct LowerAIRChannelsPattern : public OpRewritePattern<air::ChannelOp> {
   template <typename MyOp>
   LogicalResult findChannelPutGetTile(MyOp op, Value *tile,
                                       AIE::AIEObjectFifoType *datatype) const {
-    MemRefType memref = op.getMemref().getType().template cast<MemRefType>();
+    MemRefType memref = llvm::cast<MemRefType>(op.getMemref().getType());
     int mem_space = memref.getMemorySpaceAsInt();
     *datatype = AIE::AIEObjectFifoType::get(
         MemRefType::get(memref.getShape(), memref.getElementType()));
@@ -1393,7 +1394,7 @@ struct LowerAIRChannelsPattern : public OpRewritePattern<air::ChannelOp> {
                       AIE::ObjectFifoCreateOp objFifo, AIE::ObjectFifoPort port,
                       llvm::SmallSet<Operation *, 1> &erased_allocs) const {
-    MemRefType memref = op.getMemref().getType().template cast<MemRefType>();
+    MemRefType memref = cast<MemRefType>(op.getMemref().getType());
     int mem_space = memref.getMemorySpaceAsInt();
     if (mem_space == (int)air::MemorySpace::L2) {
       // add alloc to list of ops to erase
@@ -1427,7 +1428,7 @@ struct LowerAIRChannelsPattern : public OpRewritePattern<air::ChannelOp> {
       PatternRewriter &rewriter, MyOp op, AIE::ObjectFifoCreateOp objFifo,
       AIE::ObjectFifoPort port,
       llvm::SmallSet<Operation *, 1> &erased_deallocs) const {
-    MemRefType memref = op.getMemref().getType().template cast<MemRefType>();
+    MemRefType memref = llvm::cast<MemRefType>(op.getMemref().getType());
     int mem_space = memref.getMemorySpaceAsInt();
     if (mem_space == (int)air::MemorySpace::L2) {
       return;
@@ -1466,7 +1467,7 @@ void lowerAIRChannels(
 
 // Get owner (scf.parallelop) of channel indices
 scf::ParallelOp getChannelIndicesOwner(Value val) {
-  auto ivArg = val.dyn_cast<BlockArgument>();
+  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
   if (!ivArg)
     return scf::ParallelOp();
   if (!ivArg.getOwner()) {
@@ -1823,7 +1824,7 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
         }
         // Substituting index operands, such as strides and offsets, to constant
         // zero for convenience. TODO: generalize this
-        else if (operand.getType().isa<IndexType>()) {
+        else if (llvm::isa<IndexType>(operand.getType())) {
           remap.map(operand, builder.create<arith::ConstantIndexOp>(
                                  builder.getUnknownLoc(), 0));
        }
@@ -1847,7 +1848,7 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
           auto cloned_alloc = builder.clone(*memalloc, remap);
           clearAsyncDependenciesOfAsyncOp(cloned_alloc);
         } else {
-          MemRefType ty = memref.getType().cast<MemRefType>();
+          MemRefType ty = llvm::cast<MemRefType>(memref.getType());
           auto alloc_op = builder.create<memref::AllocOp>(
               builder.getUnknownLoc(),
               MemRefType::get(ty.getShape(), ty.getElementType(),
@@ -1982,7 +1983,7 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
     for (auto key : keys) {
       auto memref = chanOpPartitions[key][0].getMemref();
       auto allocOp = memref.getDefiningOp();
-      MemRefType ty = memref.getType().cast<MemRefType>();
+      MemRefType ty = llvm::cast<MemRefType>(memref.getType());
       SmallVector<int64_t> newMemrefShape;
       for (unsigned i = 0; i < air::getTensorShape(ty).size(); i++) {
         newMemrefShape.push_back(air::getTensorShape(ty)[i]);
@@ -2048,7 +2049,7 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
     std::vector<Value> memrefs;
     d.walk([&](memref::AllocOp allocOp) {
       auto memref = allocOp.getMemref();
-      auto memrefTy = memref.getType().cast<MemRefType>();
+      auto memrefTy = llvm::cast<MemRefType>(memref.getType());
       if (memrefTy.getMemorySpaceAsInt() == (int)air::MemorySpace::L2) {
         // Count the number of unique incoming and outgoing channels.
         std::vector<std::string> uniqueS2MMChannels;
@@ -2405,14 +2406,15 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
           if (auto tile_side_dmamemcpy = dyn_cast<air::DmaMemcpyNdOp>(
                   tile_side_memcpy.getOperation())) {
             if (isMM2S)
-              memref_ty =
-                  tile_side_memcpy.getDstMemref().getType().cast<MemRefType>();
+              memref_ty = llvm::cast<MemRefType>(
+                  tile_side_memcpy.getDstMemref().getType());
             else
-              memref_ty =
-                  tile_side_memcpy.getSrcMemref().getType().cast<MemRefType>();
+              memref_ty = llvm::cast<MemRefType>(
+                  tile_side_memcpy.getSrcMemref().getType());
           } else if (auto tile_side_chan = dyn_cast<air::ChannelInterface>(
                          tile_side_memcpy.getOperation())) {
-            memref_ty = tile_side_chan.getMemref().getType().cast<MemRefType>();
+            memref_ty =
+                llvm::cast<MemRefType>(tile_side_chan.getMemref().getType());
           }
 
           builder.create<AIE::ShimDMAAllocationOp>(builder.getUnknownLoc(),
                                                    dma_name,
@@ -2477,14 +2479,15 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
          if (auto tile_side_dmamemcpy = dyn_cast<air::DmaMemcpyNdOp>(
                  tile_side_memcpy.getOperation())) {
            if (isMM2S)
-              memref_ty =
-                  tile_side_memcpy.getDstMemref().getType().cast<MemRefType>();
+              memref_ty = llvm::cast<MemRefType>(
+                  tile_side_memcpy.getDstMemref().getType());
            else
-              memref_ty =
-                  tile_side_memcpy.getSrcMemref().getType().cast<MemRefType>();
+              memref_ty = llvm::cast<MemRefType>(
+                  tile_side_memcpy.getSrcMemref().getType());
          } else if (auto tile_side_chan = dyn_cast<air::ChannelInterface>(
                         tile_side_memcpy.getOperation())) {
-            memref_ty = tile_side_chan.getMemref().getType().cast<MemRefType>();
+            memref_ty =
+                llvm::cast<MemRefType>(tile_side_chan.getMemref().getType());
          }
 
          builder.create<AIE::ShimDMAAllocationOp>(builder.getUnknownLoc(),
                                                   dma_name,
diff --git a/mlir/lib/Conversion/AIRToAIESchedulingUtils.cpp b/mlir/lib/Conversion/AIRToAIESchedulingUtils.cpp
index 445331e7f..a406df23c 100644
--- a/mlir/lib/Conversion/AIRToAIESchedulingUtils.cpp
+++ b/mlir/lib/Conversion/AIRToAIESchedulingUtils.cpp
@@ -21,14 +21,12 @@ using namespace xilinx;
 
 bool air::isTileInbound(air::MemcpyInterface memcpyOp, int tileMemSpaceAsInt) {
   if (memcpyOp.getSrcMemref() && memcpyOp.getDstMemref()) {
-    int src_memory_space = memcpyOp.getSrcMemref()
-                               .getType()
-                               .cast<MemRefType>()
-                               .getMemorySpaceAsInt();
-    int dst_memory_space = memcpyOp.getDstMemref()
-                               .getType()
-                               .cast<MemRefType>()
-                               .getMemorySpaceAsInt();
+    int src_memory_space =
+        llvm::cast<MemRefType>(memcpyOp.getSrcMemref().getType())
+            .getMemorySpaceAsInt();
+    int dst_memory_space =
+        llvm::cast<MemRefType>(memcpyOp.getDstMemref().getType())
+            .getMemorySpaceAsInt();
     assert(src_memory_space !=
            dst_memory_space); // air.dmaMemcpyNdOp isn't meant to represent
                               // core-to-core communication
@@ -225,7 +223,7 @@ std::pair<int, int> air::getLockValuePair(AIE::AIEArch arch,
 
   // Infer semaphore lock values using buffer op
   // TODO: What if a buffer memref is read or written by multiple channels?
-  if (!buffer_memref.getType().isa<MemRefType>())
+  if (!llvm::isa<MemRefType>(buffer_memref.getType()))
     return std::make_pair(-1, -1);
   int read_counter = 0;
   int write_counter = 0;
@@ -264,7 +262,7 @@ std::pair<int, int> air::getLockValuePair(AIE::AIEArch arch,
   bool isAIE2 = (arch == AIE::AIEArch::AIE2);
   if (!isAIE2)
     return std::make_pair(0, 0);
-  if (!buffer_memref.getType().isa<MemRefType>())
+  if (!llvm::isa<MemRefType>(buffer_memref.getType()))
     return std::make_pair(-1, -1);
 
   if (!air_chan)
@@ -653,7 +651,7 @@ ShimDMAAllocator::getBuffer(uint64_t &BufferId, int64_t col, int64_t row,
   auto memref =
       (isMM2S) ? (memcpyOp.getSrcMemref()) : (memcpyOp.getDstMemref());
   assert(memref);
-  MemRefType memrefTy = memref.getType().cast<MemRefType>();
+  MemRefType memrefTy = llvm::cast<MemRefType>(memref.getType());
   // External buffers have memory space L3
   memrefTy = MemRefType::get(memrefTy.getShape(), memrefTy.getElementType(),
                              {}, DMAMemorySpaceAsInt);
@@ -836,15 +834,13 @@ AIE::BufferOp MemTileDMAAllocator::getBuffer(uint64_t, int64_t col, int64_t row,
 
 void MemcpyBundleAsFlow::pushBackMemcpyOpToBundle(air::DmaMemcpyNdOp memcpyOp) {
   // air::DmaMemcpyNdOp is a complete memcpy with both src and dst
   S2MM[0].push_back(memcpyOp.getOperation());
-  S2MM_memspace_as_int = memcpyOp.getDstMemref()
-                             .getType()
-                             .cast<MemRefType>()
-                             .getMemorySpaceAsInt();
+  S2MM_memspace_as_int =
+      llvm::cast<MemRefType>(memcpyOp.getDstMemref().getType())
+          .getMemorySpaceAsInt();
   MM2S.push_back(memcpyOp.getOperation());
-  MM2S_memspace_as_int = memcpyOp.getSrcMemref()
-                             .getType()
-                             .cast<MemRefType>()
-                             .getMemorySpaceAsInt();
+  MM2S_memspace_as_int =
+      llvm::cast<MemRefType>(memcpyOp.getSrcMemref().getType())
+          .getMemorySpaceAsInt();
 }
 
 void MemcpyBundleAsFlow::pushBackMemcpyOpToBundle(air::ChannelGetOp memcpyOp) {
@@ -875,16 +871,16 @@ void MemcpyBundleAsFlow::pushBackMemcpyOpToBundle(air::ChannelGetOp memcpyOp) {
   }
   air_flow_op = chan.getOperation();
   S2MM[alloc_id].push_back(memcpyOp.getOperation());
-  S2MM_memspace_as_int =
-      memcpyOp.getMemref().getType().cast<MemRefType>().getMemorySpaceAsInt();
+  S2MM_memspace_as_int = llvm::cast<MemRefType>(memcpyOp.getMemref().getType())
+                             .getMemorySpaceAsInt();
 }
 
 void MemcpyBundleAsFlow::pushBackMemcpyOpToBundle(air::ChannelPutOp memcpyOp) {
   auto chan = air::getChannelDeclarationThroughSymbol(memcpyOp);
   air_flow_op = chan.getOperation();
   MM2S.push_back(memcpyOp.getOperation());
-  MM2S_memspace_as_int =
-      memcpyOp.getMemref().getType().cast<MemRefType>().getMemorySpaceAsInt();
+  MM2S_memspace_as_int = llvm::cast<MemRefType>(memcpyOp.getMemref().getType())
+                             .getMemorySpaceAsInt();
 }
 
 void MemcpyBundleAsFlow::pushBackMemcpyOpToBundle(
diff --git a/mlir/lib/Conversion/AIRToAsyncPass.cpp b/mlir/lib/Conversion/AIRToAsyncPass.cpp
index ceeff2c4a..3b549a493 100644
--- a/mlir/lib/Conversion/AIRToAsyncPass.cpp
+++ b/mlir/lib/Conversion/AIRToAsyncPass.cpp
@@ -151,14 +151,14 @@ static func::CallOp convertOpToFunction(Operation *op, ArrayRef<Value> operands,
   SmallVector<Value> dependencies;
   for (auto o : operands) {
     // erase the size to reduce the number of manglings
-    if (auto memrefTy = o.getType().dyn_cast<MemRefType>()) {
+    if (auto memrefTy = llvm::dyn_cast<MemRefType>(o.getType())) {
       auto t = MemRefType::get(
           std::vector<int64_t>(memrefTy.getRank(), ShapedType::kDynamic),
          memrefTy.getElementType(), memrefTy.getLayout(),
          /*memrefTy.getMemorySpace()*/ 0);
       callops.push_back(
           rewriter.create<memref::CastOp>(loc, t, o).getResult(0));
-    } else if (o.getType().isa<air::AsyncTokenType>()) {
+    } else if (llvm::isa<air::AsyncTokenType>(o.getType())) {
       dependencies.push_back(o);
     } else {
       callops.push_back(o);
@@ -168,14 +168,14 @@ static func::CallOp convertOpToFunction(Operation *op, ArrayRef<Value> operands,
   SmallVector<Type> real_result_tys;
   SmallVector<Type> token_result_tys;
   for (auto t : op->getResultTypes()) {
-    if (auto memrefTy = t.dyn_cast<MemRefType>()) {
+    if (auto memrefTy = llvm::dyn_cast<MemRefType>(t)) {
       auto mrt = MemRefType::get(
           std::vector<int64_t>(memrefTy.getRank(), ShapedType::kDynamic),
          memrefTy.getElementType(), memrefTy.getLayout(),
          /*memrefTy.getMemorySpace()*/ 0);
       retTys.push_back(mrt);
       real_result_tys.push_back(memrefTy);
-    } else if (t.isa<air::AsyncTokenType>()) {
+    } else if (llvm::isa<air::AsyncTokenType>(t)) {
       token_result_tys.push_back(t);
     } else {
       retTys.push_back(t);
@@ -204,7 +204,7 @@ static func::CallOp convertOpToFunction(Operation *op, ArrayRef<Value> operands,
     results = call.getResults();
   for (unsigned i = 0, real_result_idx = 0; i < results.size(); ++i) {
     auto r = results[i];
-    if (auto memrefTy = r.getType().dyn_cast<MemRefType>()) {
+    if (auto memrefTy = llvm::dyn_cast<MemRefType>(r.getType())) {
       auto t = real_result_tys[real_result_idx++];
       auto c =
           rewriter.create<memref::CastOp>(op->getLoc(), t, r);
@@ -320,7 +320,7 @@ class DeallocOpConversion : public OpConversionPattern<memref::DeallocOp> {
   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto memrefTy = op.getMemref().getType().cast<MemRefType>();
+    auto memrefTy = llvm::cast<MemRefType>(op.getMemref().getType());
     if (memrefTy.getMemorySpaceAsInt() == (int)air::MemorySpace::L3)
       return failure();
 
@@ -341,7 +341,7 @@ class AsyncCallOpConversion : public OpConversionPattern {
 
     SmallVector<Type> retTy;
     for (auto t : op.getResultTypes())
-      if (t.isa<air::AsyncTokenType>())
+      if (llvm::isa<air::AsyncTokenType>(t))
        retTy.push_back(async::TokenType::get(op->getContext()));
      else
        retTy.push_back(t);
@@ -382,7 +382,7 @@ class WaitAllOpConversion : public OpConversionPattern<air::WaitAllOp> {
 
     for (auto o : operands) {
       Value v = o;
-      // if (o.getType().isa<air::AsyncTokenType>())
+      // if (llvm::isa<air::AsyncTokenType>(o.getType()))
       //   v = rewriter.create<UnrealizedConversionCastOp>(op->getLoc(),
       //           async::TokenType::get(op->getContext()),
       //           o).getResult(0);
@@ -403,7 +403,7 @@ class ScfYieldOpConversion : public OpConversionPattern<scf::YieldOp> {
     SmallVector<Value> operands{adaptor.getOperands()};
     SmallVector<Type> retTys;
     for (auto t : op->getResultTypes()) {
-      if (t.isa<air::AsyncTokenType>()) {
+      if (llvm::isa<air::AsyncTokenType>(t)) {
        retTys.push_back(async::TokenType::get(op->getContext()));
      } else {
        retTys.push_back(t);
@@ -596,7 +596,7 @@ struct ChannelOpConversion : public OpConversionPattern<air::ChannelOp> {
 
     SmallVector<int64_t> shape;
     for (auto i : op.getSize()) {
-      shape.push_back(i.dyn_cast<IntegerAttr>().getInt());
+      shape.push_back(llvm::dyn_cast<IntegerAttr>(i).getInt());
     }
     // if channel dim < 2, add until dim = 2
     while (shape.size() < 2) {
@@ -678,9 +678,10 @@ class ChannelPutOpConversion : public OpConversionPattern<air::ChannelPutOp> {
     }
     // if channel is broadcast, add broadcast shape
     if (channelOp->getAttr("broadcast_shape")) {
-      for (auto i : channelOp->getAttr("broadcast_shape").cast<ArrayAttr>()) {
+      for (auto i :
+           llvm::cast<ArrayAttr>(channelOp->getAttr("broadcast_shape"))) {
         operands.push_back(rewriter.create<arith::ConstantIndexOp>(
-            op->getLoc(), i.cast<IntegerAttr>().getInt()));
+            op->getLoc(), llvm::cast<IntegerAttr>(i).getInt()));
       }
     } else {
       // if channel is not broadcast, add 1
@@ -717,9 +718,9 @@ class AIRToAsyncPass : public air::impl::AIRToAsyncBase<AIRToAsyncPass> {
     TypeConverter converter;
     converter.addConversion([&](Type type) -> std::optional<Type> {
       // convert air::AsyncTokenType to async::TokenType
-      if (auto t = type.dyn_cast<air::AsyncTokenType>())
+      if (auto t = llvm::dyn_cast<air::AsyncTokenType>(type))
        return async::TokenType::get(context);
-      if (auto t = type.dyn_cast<MemRefType>())
+      if (auto t = llvm::dyn_cast<MemRefType>(type))
        if (t.getMemorySpaceAsInt() != 0)
          return MemRefType::get(t.getShape(), t.getElementType(),
                                 t.getLayout(), 0);
@@ -768,9 +769,9 @@ class AIRToAsyncPass : public air::impl::AIRToAsyncBase<AIRToAsyncPass> {
 
     target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
       auto isIllegal = [](Type t) {
-        if (t.isa<air::AsyncTokenType>())
+        if (llvm::isa<air::AsyncTokenType>(t))
          return true;
-        if (auto mt = t.dyn_cast<MemRefType>())
+        if (auto mt = llvm::dyn_cast<MemRefType>(t))
          return mt.getMemorySpaceAsInt() != 0;
        return false;
      };
@@ -780,7 +781,7 @@ class AIRToAsyncPass : public air::impl::AIRToAsyncBase<AIRToAsyncPass> {
     target.addDynamicallyLegalOp<scf::ForOp>([&](scf::ForOp op) {
       for (auto o : op.getRegionIterArgs()) {
-        if (o.getType().isa<air::AsyncTokenType>())
+        if (llvm::isa<air::AsyncTokenType>(o.getType()))
          return false;
      }
      return true;
@@ -791,7 +792,7 @@ class AIRToAsyncPass : public air::impl::AIRToAsyncBase<AIRToAsyncPass> {
     target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
       for (auto v : op.getResults()) {
-        if (v.getType().isa<air::AsyncTokenType>())
+        if (llvm::isa<air::AsyncTokenType>(v.getType()))
          return false;
      }
      return true;
@@ -802,9 +803,8 @@ class AIRToAsyncPass : public air::impl::AIRToAsyncBase<AIRToAsyncPass> {
     });
 
     target.addDynamicallyLegalOp<memref::DeallocOp>([&](memref::DeallocOp op) {
-      return (
-          op.getMemref().getType().cast<MemRefType>().getMemorySpaceAsInt() ==
-          0);
+      return (llvm::cast<MemRefType>(op.getMemref().getType())
+                  .getMemorySpaceAsInt() == 0);
     });
 
     RewritePatternSet typeConversionPatterns(context);
diff --git a/mlir/lib/Conversion/ConvertToAIRPass.cpp b/mlir/lib/Conversion/ConvertToAIRPass.cpp
index f616a0992..d831199db 100644
--- a/mlir/lib/Conversion/ConvertToAIRPass.cpp
+++ b/mlir/lib/Conversion/ConvertToAIRPass.cpp
@@ -64,8 +64,8 @@ matchAndRewriteCopyOp(memref::CopyOp op, RewriterBase &rewriter) {
   rewriter.setInsertionPoint(op);
 
   // It must already be a memref
-  auto src_type = src.getType().dyn_cast<MemRefType>();
-  auto dst_type = dst.getType().dyn_cast<MemRefType>();
+  auto src_type = llvm::dyn_cast<MemRefType>(src.getType());
+  auto dst_type = llvm::dyn_cast<MemRefType>(dst.getType());
   if (!src_type)
     return failure();
 
@@ -91,10 +91,10 @@ matchAndRewriteCopyOp(memref::CopyOp op, RewriterBase &rewriter) {
     auto loc = subview.getLoc();
 
     // get the strides and offsets from the memref type
-    auto inferredType = memref::SubViewOp::inferResultType(
-                            subview.getSourceType(), static_offsets,
-                            static_sizes, static_strides)
-                            .cast<MemRefType>();
+    auto inferredType =
+        llvm::cast<MemRefType>(memref::SubViewOp::inferResultType(
+            subview.getSourceType(), static_offsets, static_sizes,
+            static_strides));
     int64_t offset;
     SmallVector<int64_t> layout_strides;
     auto successStrides =
@@ -163,10 +163,8 @@ static void extractOperandsFromSubview(memref::SubViewOp subview,
   auto loc = subview.getLoc();
 
   // get the strides and offsets from the memref type
-  auto inferredType =
-      memref::SubViewOp::inferResultType(
-          subview.getSourceType(), static_offsets, static_sizes, static_strides)
-          .cast<MemRefType>();
+  auto inferredType = llvm::cast<MemRefType>(memref::SubViewOp::inferResultType(
+      subview.getSourceType(), static_offsets, static_sizes, static_strides));
   int64_t offset;
   SmallVector<int64_t> layout_strides;
   auto successStrides =
@@ -595,8 +593,8 @@ class LinalgCopyToAIRDmaConversion : public OpRewritePattern<linalg::CopyOp> {
     auto dst = op.getOutputs()[0];
 
     // It must already be a memref
-    auto src_type = src.getType().dyn_cast<MemRefType>();
-    auto dst_type = dst.getType().dyn_cast<MemRefType>();
+    auto src_type = llvm::dyn_cast<MemRefType>(src.getType());
+    auto dst_type = llvm::dyn_cast<MemRefType>(dst.getType());
    if (!src_type)
      return failure();
 
@@ -722,8 +720,8 @@ void replaceAIRDmaWithAIRChannelPairs(
   auto dst = op.getDstMemref();
   auto ctx = op->getContext();
 
-  auto src_type = src.getType().dyn_cast<MemRefType>();
-  auto dst_type = dst.getType().dyn_cast<MemRefType>();
+  auto src_type = llvm::dyn_cast<MemRefType>(src.getType());
+  auto dst_type = llvm::dyn_cast<MemRefType>(dst.getType());
   SmallVector<Value> src_offsets = op.getSrcOffsets();
   SmallVector<Value> dst_offsets = op.getDstOffsets();
   SmallVector<Value> src_sizes = op.getSrcSizes();
@@ -947,7 +945,7 @@ void HoistingAffineIf(affine::AffineIfOp op) {
         // Hoist hierarchy op into scf op
         module_builder.setInsertionPoint(hier_op);
         MemRefType externalMemrefTy =
-            externalGetPut[0].getMemref().getType().cast<MemRefType>();
+            llvm::cast<MemRefType>(externalGetPut[0].getMemref().getType());
         if (externalMemrefTy.getMemorySpaceAsInt() == (int)air::MemorySpace::L3 &&
             segment) {
           module_builder.setInsertionPoint(segment);
@@ -1055,8 +1053,8 @@ class AIRDmaToAIRChannelConversion
     auto ctx = op->getContext();
 
     // It must already be a memref
-    auto src_type = src.getType().dyn_cast<MemRefType>();
-    auto dst_type = dst.getType().dyn_cast<MemRefType>();
+    auto src_type = llvm::dyn_cast<MemRefType>(src.getType());
+    auto dst_type = llvm::dyn_cast<MemRefType>(dst.getType());
    if (!src_type)
      return failure();
 
@@ -1363,7 +1361,7 @@ LogicalResult AIRDemoteMemrefToAIRHierarchy(
       auto memref = isa<air::ExecuteOp>(op) ? op->getResult(1) : op->getResult(0);
       auto token = isa<air::ExecuteOp>(op) ? op->getResult(0) : nullptr;
-      auto memref_type = memref.getType().dyn_cast<MemRefType>();
+      auto memref_type = llvm::dyn_cast<MemRefType>(memref.getType());
 
       if (memref_type.getMemorySpaceAsInt() == hierMemorySpace)
         continue; // Alloc op is already under correct hierarchy
@@ -1447,8 +1445,8 @@ class AIRDemoteDmaToAIRHierarchyConversion
     auto ctx = op->getContext();
 
     // It must already be a memref
-    auto src_type = src.getType().dyn_cast<MemRefType>();
-    auto dst_type = dst.getType().dyn_cast<MemRefType>();
+    auto src_type = llvm::dyn_cast<MemRefType>(src.getType());
+    auto dst_type = llvm::dyn_cast<MemRefType>(dst.getType());
    if (!src_type)
      return failure();
 
@@ -2253,9 +2251,9 @@ LogicalResult TileL1L2AIRMemcpyUsingScfParallel(air::DmaMemcpyNdOp op,
     for (unsigned i = 0; i < L2MemrefShape.size(); i++)
       L2Sizes.push_back(builder.getIndexAttr(L2TiledShape[i]));
     auto subviewOutputType =
-        memref::SubViewOp::inferResultType(L2Memref.getType().cast<MemRefType>(),
-                                           L2Offsets, L2Sizes, L2Strides)
-            .cast<MemRefType>();
+        llvm::cast<MemRefType>(memref::SubViewOp::inferResultType(
+            llvm::cast<MemRefType>(L2Memref.getType()), L2Offsets, L2Sizes,
+            L2Strides));
     auto newL2Subview = builder.create<memref::SubViewOp>(
         loc, subviewOutputType, L2Memref, L2Offsets, L2Sizes, L2Strides);
     remap.map(L2Memref, newL2Subview.getResult());
@@ -2650,8 +2648,8 @@ static LogicalResult condenseMemrefDataReorderingToAIRDma(
   auto loc = dmaOp->getLoc();
 
   // It must already be a memref
-  auto src_type = src.getType().dyn_cast<MemRefType>();
-  auto dst_type = dst.getType().dyn_cast<MemRefType>();
+  auto src_type = llvm::dyn_cast<MemRefType>(src.getType());
+  auto dst_type = llvm::dyn_cast<MemRefType>(dst.getType());
   if (!src_type)
     return failure();
   if (!(src_type.hasStaticShape() || dst_type.hasStaticShape()))
@@ -2675,7 +2673,7 @@ static LogicalResult condenseMemrefDataReorderingToAIRDma(
       src = subviewOp.getSource();
     } else if (auto transposeOp =
                    dyn_cast<memref::TransposeOp>(src_ancestor_memref_ops[0])) {
-      src_memref_ty = transposeOp.getIn().getType().cast<MemRefType>();
+      src_memref_ty = llvm::cast<MemRefType>(transposeOp.getIn().getType());
       src = transposeOp.getIn();
     }
   }
@@ -2688,7 +2686,7 @@ static LogicalResult condenseMemrefDataReorderingToAIRDma(
       dst = subviewOp.getSource();
     } else if (auto transposeOp =
                    dyn_cast<memref::TransposeOp>(dst_ancestor_memref_ops[0])) {
-      dst_memref_ty = transposeOp.getIn().getType().cast<MemRefType>();
+      dst_memref_ty = llvm::cast<MemRefType>(transposeOp.getIn().getType());
       dst = transposeOp.getIn();
     }
   }
@@ -2710,18 +2708,16 @@ static LogicalResult condenseMemrefDataReorderingToAIRDma(
     } else if (auto subviewOp = dyn_cast<memref::SubViewOp>(memrefOp)) {
       // Check if subview is rank reduced
       if (subviewOp.getSourceType().getRank() > subviewOp.getType().getRank())
-        src_memref_ty =
+        src_memref_ty = llvm::cast<MemRefType>(
             memref::SubViewOp::inferRankReducedResultType(
                 subviewOp.getType().getShape(), src_memref_ty,
                 subviewOp.getStaticOffsets(), subviewOp.getStaticSizes(),
-                subviewOp.getStaticStrides())
-                .cast<MemRefType>();
+                subviewOp.getStaticStrides()));
      else
        src_memref_ty =
-            memref::SubViewOp::inferResultType(
+            llvm::cast<MemRefType>(memref::SubViewOp::inferResultType(
                src_memref_ty, subviewOp.getStaticOffsets(),
-                subviewOp.getStaticSizes(), subviewOp.getStaticStrides())
-                .cast<MemRefType>();
+                subviewOp.getStaticSizes(), subviewOp.getStaticStrides()));
    }
  }
@@ -2741,18 +2737,16 @@ static LogicalResult condenseMemrefDataReorderingToAIRDma(
       }
     } else if (auto subviewOp = dyn_cast<memref::SubViewOp>(memrefOp)) {
       if (subviewOp.getSourceType().getRank() > subviewOp.getType().getRank())
-        dst_memref_ty =
+        dst_memref_ty = llvm::cast<MemRefType>(
            memref::SubViewOp::inferRankReducedResultType(
                subviewOp.getType().getShape(), dst_memref_ty,
                subviewOp.getStaticOffsets(), subviewOp.getStaticSizes(),
-                subviewOp.getStaticStrides())
-                .cast<MemRefType>();
+                subviewOp.getStaticStrides()));
      else
        dst_memref_ty =
-            memref::SubViewOp::inferResultType(
+            llvm::cast<MemRefType>(memref::SubViewOp::inferResultType(
                dst_memref_ty, subviewOp.getStaticOffsets(),
-                subviewOp.getStaticSizes(), subviewOp.getStaticStrides())
-                .cast<MemRefType>();
+                subviewOp.getStaticSizes(), subviewOp.getStaticStrides()));
    }
  }
@@ -2768,12 +2762,12 @@ static LogicalResult condenseMemrefDataReorderingToAIRDma(
   SmallVector<Value> deps;
   SmallVector<Type> tys;
 
-  if (failed(canonicalizeAIRDmaOperands(rewriter, src_offsets, src_sizes,
-                                        src_strides,
-                                        src.getType().cast<MemRefType>())) ||
-      failed(canonicalizeAIRDmaOperands(rewriter, dst_offsets, dst_sizes,
-                                        dst_strides,
-                                        dst.getType().cast<MemRefType>()))) {
+  if (failed(canonicalizeAIRDmaOperands(
+          rewriter, src_offsets, src_sizes, src_strides,
+          llvm::cast<MemRefType>(src.getType()))) ||
+      failed(canonicalizeAIRDmaOperands(
+          rewriter, dst_offsets, dst_sizes, dst_strides,
+          llvm::cast<MemRefType>(dst.getType())))) {
     assert(false);
   }
   auto new_dma = rewriter.create<air::DmaMemcpyNdOp>(
@@ -2816,8 +2810,8 @@ struct CopyToDmaPass : public air::impl::CopyToDmaBase<CopyToDmaPass> {
                            affine::AffineYieldOp>();
 
     target.addDynamicallyLegalOp<memref::CopyOp>([](memref::CopyOp co) {
-      auto src_type = co.getSource().getType().dyn_cast<MemRefType>();
-      auto dst_type = co.getTarget().getType().dyn_cast<MemRefType>();
+      auto src_type = llvm::dyn_cast<MemRefType>(co.getSource().getType());
+      auto dst_type = llvm::dyn_cast<MemRefType>(co.getTarget().getType());
       return src_type.getMemorySpaceAsInt() == dst_type.getMemorySpaceAsInt();
     });
 
@@ -3005,8 +2999,10 @@ struct DmaToChannelPass : public air::impl::DmaToChannelBase<DmaToChannelPass> {
 
     target_0.addDynamicallyLegalOp<air::DmaMemcpyNdOp>(
         [&](air::DmaMemcpyNdOp dma) {
-          auto src_type = dma.getSrcMemref().getType().dyn_cast<MemRefType>();
-          auto dst_type = dma.getDstMemref().getType().dyn_cast<MemRefType>();
+          auto src_type =
+              llvm::dyn_cast<MemRefType>(dma.getSrcMemref().getType());
+          auto dst_type =
+              llvm::dyn_cast<MemRefType>(dma.getDstMemref().getType());
           if (dma->getParentOfType<air::HerdOp>()) {
             if (src_type.getMemorySpaceAsInt() < (int)air::MemorySpace::L1 &&
                 dst_type.getMemorySpaceAsInt() < (int)air::MemorySpace::L1)
@@ -3179,11 +3175,11 @@ static void getHerdNames(ModuleOp module) {
             continue;
           if (!isa<MemRefType>(operJ.getType()))
             continue;
-          if (operI.getType().cast<MemRefType>().getMemorySpaceAsInt() !=
-              (int)air::MemorySpace::L1)
+          if (llvm::cast<MemRefType>(operI.getType())
+                  .getMemorySpaceAsInt() != (int)air::MemorySpace::L1)
             continue;
-          if (operJ.getType().cast<MemRefType>().getMemorySpaceAsInt() !=
-              (int)air::MemorySpace::L1)
+          if (llvm::cast<MemRefType>(operJ.getType())
+                  .getMemorySpaceAsInt() != (int)air::MemorySpace::L1)
             continue;
           if (operI != operJ)
             continue;
@@ -3288,8 +3284,8 @@ struct ParallelToHerdPass
     // Ensure that air.dma_memcpy_nd ops between L1 and L2 are within at least
     // two parent scf.parallel loops.
     module.walk([&](air::DmaMemcpyNdOp op) {
-      auto srcMemrefTy = op.getSrcMemref().getType().cast<MemRefType>();
-      auto dstMemrefTy = op.getDstMemref().getType().cast<MemRefType>();
+      auto srcMemrefTy = llvm::cast<MemRefType>(op.getSrcMemref().getType());
+      auto dstMemrefTy = llvm::cast<MemRefType>(op.getDstMemref().getType());
       Value L1Memref = nullptr;
       Value L2Memref = nullptr;
       bool SrcIsL1 = false;
diff --git a/mlir/lib/Dialect/AIR/TransformOps/AIRTransformOps.cpp b/mlir/lib/Dialect/AIR/TransformOps/AIRTransformOps.cpp
index d1455c0f6..57b723d09 100644
--- a/mlir/lib/Dialect/AIR/TransformOps/AIRTransformOps.cpp
+++ b/mlir/lib/Dialect/AIR/TransformOps/AIRTransformOps.cpp
@@ -38,7 +38,7 @@ transform::GetSegmentForOp::apply(transform::TransformRewriter &rewriter,
     }
     segments.insert(segment);
   }
-  results.set(getResult().cast<OpResult>(), segments.getArrayRef());
+  results.set(llvm::cast<OpResult>(getResult()), segments.getArrayRef());
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/lib/Targets/AIRTargets.cpp b/mlir/lib/Targets/AIRTargets.cpp
index 008632a90..6e4e59a79 100644
--- a/mlir/lib/Targets/AIRTargets.cpp
+++ b/mlir/lib/Targets/AIRTargets.cpp
@@ -55,14 +55,14 @@ static llvm::cl::opt
                     llvm::cl::init(0));
 
 llvm::json::Value attrToJSON(Attribute &attr) {
-  if (auto a = attr.dyn_cast<StringAttr>()) {
+  if (auto a = llvm::dyn_cast<StringAttr>(attr)) {
     return llvm::json::Value(a.getValue().str());
-  } else if (auto array_attr = attr.dyn_cast<ArrayAttr>()) {
+  } else if (auto array_attr = llvm::dyn_cast<ArrayAttr>(attr)) {
     llvm::json::Array arrayJSON;
     for (auto a : array_attr)
       arrayJSON.push_back(attrToJSON(a));
     return llvm::json::Value(std::move(arrayJSON));
-  } else if (auto dict_attr = attr.dyn_cast<DictionaryAttr>()) {
+  } else if (auto dict_attr = llvm::dyn_cast<DictionaryAttr>(attr)) {
     llvm::json::Object dictJSON;
     for (auto a : dict_attr) {
       auto ident = a.getName();
@@ -70,7 +70,7 @@ llvm::json::Value attrToJSON(Attribute &attr) {
       dictJSON[ident.str()] = attrToJSON(attr);
     }
     return llvm::json::Value(std::move(dictJSON));
-  } else if (auto int_attr = attr.dyn_cast<IntegerAttr>()) {
+  } else if (auto int_attr = llvm::dyn_cast<IntegerAttr>(attr)) {
     return llvm::json::Value(int_attr.getInt());
   } else
     return llvm::json::Value(std::string(""));
diff --git a/mlir/lib/Transform/AIRDependency.cpp b/mlir/lib/Transform/AIRDependency.cpp
index 3ee5cad9b..78dd2aab3 100644
--- a/mlir/lib/Transform/AIRDependency.cpp
+++ b/mlir/lib/Transform/AIRDependency.cpp
@@ -167,7 +167,7 @@ class AIRDependency
 
        // Create async execute region for arith.muli
        else if (auto arith_op = dyn_cast<arith::MulIOp>(op)) {
-          if (arith_op.getResult().getType().isa<IndexType>()) {
+          if (llvm::isa<IndexType>(arith_op.getResult().getType())) {
            createAsyncExecute(module_builder, op, "arith::muli", ExecuteOpID,
                               arith_op.getResult().getType());
          }
@@ -175,7 +175,7 @@ class AIRDependency
 
        // Create async execute region for arith.addi
        else if (auto arith_op = dyn_cast<arith::AddIOp>(op)) {
-          if (arith_op.getResult().getType().isa<IndexType>()) {
+          if (llvm::isa<IndexType>(arith_op.getResult().getType())) {
            createAsyncExecute(module_builder, op, "arith::addi", ExecuteOpID,
                               arith_op.getResult().getType());
          }
@@ -198,8 +198,8 @@ class AIRDependency
        else {
          bool isCandidateExecute = false;
          for (auto operand : op->getOperands()) {
-            if (operand.getType().isa<MemRefType>() ||
-                operand.getType().isa<IndexType>()) {
+            if (llvm::isa<MemRefType>(operand.getType()) ||
+                llvm::isa<IndexType>(operand.getType())) {
              isCandidateExecute = true;
            }
          }
@@ -260,24 +260,24 @@ class AIRDependency
        // If the sink op is linalg op
        if (auto sink_op_linalgop = dyn_cast<linalg::LinalgOp>(sink_op)) {
          for (auto ins_value : sink_op_linalgop.getDpsInputs()) {
-            if (ins_value.getType().isa<MemRefType>()) {
+            if (llvm::isa<MemRefType>(ins_value.getType())) {
              unsigned memRefRank =
-                  ins_value.getType().cast<MemRefType>().getRank();
+                  llvm::cast<MemRefType>(ins_value.getType()).getRank();
              partialMemref tile = createPartialMemref(ins_value, memRefRank);
              sink_op_memref_reads.push_back(tile);
-            } else if (ins_value.getType().isa<IndexType>()) {
+            } else if (llvm::isa<IndexType>(ins_value.getType())) {
              sink_op_scalar_ins.push_back(ins_value);
            }
          }
          for (auto outs_value : sink_op_linalgop.getDpsInits()) {
-            if (outs_value.getType().isa<MemRefType>()) {
+            if (llvm::isa<MemRefType>(outs_value.getType())) {
              unsigned memRefRank =
-                  outs_value.getType().cast<MemRefType>().getRank();
+                  llvm::cast<MemRefType>(outs_value.getType()).getRank();
              partialMemref tile = createPartialMemref(outs_value, memRefRank);
              sink_op_memref_reads.push_back(
                  tile); // linalg op both reads and writes the output memref
              sink_op_memref_writes.push_back(tile);
-            } else if (outs_value.getType().isa<IndexType>()) {
+            } else if (llvm::isa<IndexType>(outs_value.getType())) {
              sink_op_scalar_ins.push_back(
                  outs_value); // linalg op both reads and writes the output
                               // memref
@@ -286,13 +286,13 @@ class AIRDependency
          }
          if (sink_op_linalgop->getNumResults()) {
            for (auto linalg_results : sink_op_linalgop->getResults()) {
-              if (linalg_results.getType().isa<MemRefType>()) {
+              if (llvm::isa<MemRefType>(linalg_results.getType())) {
                unsigned memRefRank =
-                    linalg_results.getType().cast<MemRefType>().getRank();
+                    llvm::cast<MemRefType>(linalg_results.getType()).getRank();
                partialMemref tile =
                    createPartialMemref(linalg_results, memRefRank);
                sink_op_memref_writes.push_back(tile);
-              } else if (linalg_results.getType().isa<IndexType>()) {
+              } else if (llvm::isa<IndexType>(linalg_results.getType())) {
                sink_op_scalar_outs.push_back(linalg_results);
              }
            }
@@ -302,10 +302,9 @@ class AIRDependency
 
        // If the sink op is memref::dealloc
        else if (auto sink_op_memdealloc = dyn_cast<memref::DeallocOp>(sink_op)) {
-          unsigned memRefRank = sink_op_memdealloc.getMemref()
-                                    .getType()
-                                    .cast<MemRefType>()
-                                    .getRank();
+          unsigned memRefRank =
+              llvm::cast<MemRefType>(sink_op_memdealloc.getMemref().getType())
+                  .getRank();
          partialMemref tile =
              createPartialMemref(sink_op_memdealloc.getMemref(), memRefRank);
          sink_op_memref_reads.push_back(tile);
@@ -315,17 +314,15 @@ class AIRDependency
 
        // If the sink op is memref::copy
        else if (auto sink_op_memref_copy = dyn_cast<memref::CopyOp>(sink_op)) {
-          unsigned memRefRankSrc = sink_op_memref_copy.getSource()
-                                       .getType()
-                                       .cast<MemRefType>()
-                                       .getRank();
+          unsigned memRefRankSrc =
+              llvm::cast<MemRefType>(sink_op_memref_copy.getSource().getType())
+                  .getRank();
          partialMemref tileSrc = createPartialMemref(
              sink_op_memref_copy.getSource(), memRefRankSrc);
          sink_op_memref_reads.push_back(tileSrc);
-          unsigned memRefRankDst = sink_op_memref_copy.getTarget()
-                                       .getType()
-                                       .cast<MemRefType>()
-                                       .getRank();
+          unsigned memRefRankDst =
+              llvm::cast<MemRefType>(sink_op_memref_copy.getTarget().getType())
+                  .getRank();
          partialMemref tileDst = createPartialMemref(
              sink_op_memref_copy.getTarget(), memRefRankDst);
          sink_op_memref_reads.push_back(tileDst);
@@ -337,10 +334,9 @@ class AIRDependency
                     mlir::dyn_cast<air::MemcpyInterface>(sink_op)) {
          if (sink_op_memcpy.getSrcMemref()) {
            SmallVector<Value> src_indices;
-            unsigned numDimsSrc = sink_op_memcpy.getSrcMemref()
-                                      .getType()
-                                      .cast<MemRefType>()
-                                      .getRank();
+            unsigned numDimsSrc =
+                llvm::cast<MemRefType>(sink_op_memcpy.getSrcMemref().getType())
+                    .getRank();
            for (unsigned i = 0; i < sink_op_memcpy.getSrcOffsets().size(); i++)
              sink_op_scalar_ins.push_back(sink_op_memcpy.getSrcOffsets()[i]);
            for (unsigned i = 0; i < sink_op_memcpy.getSrcSizes().size(); i++)
@@ -363,10 +359,9 @@ class AIRDependency } if (sink_op_memcpy.getDstMemref()) { SmallVector dst_indices; - unsigned numDimsDst = sink_op_memcpy.getDstMemref() - .getType() - .cast() - .getRank(); + unsigned numDimsDst = + llvm::cast(sink_op_memcpy.getDstMemref().getType()) + .getRank(); for (unsigned i = 0; i < sink_op_memcpy.getDstOffsets().size(); i++) sink_op_scalar_outs.push_back(sink_op_memcpy.getDstOffsets()[i]); for (unsigned i = 0; i < sink_op_memcpy.getDstSizes().size(); i++) @@ -415,14 +410,14 @@ class AIRDependency // If the sink op is an unknown op else { for (auto sink_op_op : sink_op->getOperands()) { - if (sink_op_op.getType().isa()) { + if (llvm::isa(sink_op_op.getType())) { unsigned memRefRank = - sink_op_op.getType().cast().getRank(); + llvm::cast(sink_op_op.getType()).getRank(); partialMemref tile = createPartialMemref(sink_op_op, memRefRank); sink_op_memref_reads.push_back( tile); // Assuming all operands are both read and written to sink_op_memref_writes.push_back(tile); - } else if (sink_op_op.getType().isa()) { + } else if (llvm::isa(sink_op_op.getType())) { sink_op_scalar_ins.push_back( sink_op_op); // Assuming all operands are both read and // written to @@ -431,13 +426,13 @@ class AIRDependency } if (sink_op->getNumResults()) { for (auto sink_op_results : sink_op->getResults()) { - if (sink_op_results.getType().isa()) { + if (llvm::isa(sink_op_results.getType())) { unsigned memRefRank = - sink_op_results.getType().cast().getRank(); + llvm::cast(sink_op_results.getType()).getRank(); partialMemref tile = createPartialMemref(sink_op_results, memRefRank); sink_op_memref_writes.push_back(tile); - } else if (sink_op_results.getType().isa()) { + } else if (llvm::isa(sink_op_results.getType())) { sink_op_scalar_outs.push_back(sink_op_results); } } @@ -956,7 +951,7 @@ class AIRDependency } char checkOperandReadOrWrite(mlir::Value operand) { - if (!operand.getType().isa()) { + if (!llvm::isa(operand.getType())) { operand.getDefiningOp()->emitOpError( "operand being traced is not a memref"); } @@ -1002,7 +997,7 @@ class AIRDependency template void pushDepsAtCurrentScope(mlir::Value operand, T op, char rw = 'n', partialMemref *tile = nullptr) { - if (!operand.getType().isa()) { + if (!llvm::isa(operand.getType())) { operand.getDefiningOp()->emitOpError( "operand being traced is not a memref"); } @@ -1028,7 +1023,8 @@ class AIRDependency if (memcpy.getSrcMemref()) { SmallVector src_indices; unsigned numDimsSrc = - memcpy.getSrcMemref().getType().cast().getRank(); + llvm::cast(memcpy.getSrcMemref().getType()) + .getRank(); if (memcpy.getSrcOffsets().size()) { numDimsSrc = memcpy.getSrcOffsets().size(); for (unsigned i = 0; i < numDimsSrc; i++) { @@ -1044,7 +1040,8 @@ class AIRDependency } if (memcpy.getDstMemref()) { unsigned numDimsDst = - memcpy.getDstMemref().getType().cast().getRank(); + llvm::cast(memcpy.getDstMemref().getType()) + .getRank(); SmallVector dst_indices; if (memcpy.getDstOffsets().size()) { numDimsDst = memcpy.getDstOffsets().size(); diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp index 68b3dbfd1..986b8b4eb 100644 --- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp +++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp @@ -563,7 +563,7 @@ struct AnnotateFrontAndBackOpsInForPattern // Check if the for loop is async SmallVector iterTokens; for (auto iter_arg : for_op.getRegionIterArgs()) { - if (iter_arg.getType().isa()) { + if (llvm::isa(iter_arg.getType())) { iterTokens.push_back(iter_arg); } } 
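The AIRDependencyScheduleOpt hunks filter loop-carried values by the AIR async-token type in the same way. A hedged sketch of the iter-arg filter, assuming the scf dialect header and this repo's AIR dialect header path; collectIterTokens is an illustrative name:

#include "air/Dialect/AIR/AIRDialect.h" // assumed include path for AsyncTokenType
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

// Collect the async-token iter_args of an scf.for, using llvm::isa on the
// block-argument types instead of the deprecated iter_arg.getType().isa<>().
static llvm::SmallVector<mlir::Value>
collectIterTokens(mlir::scf::ForOp forOp) {
  llvm::SmallVector<mlir::Value> tokens;
  for (mlir::Value arg : forOp.getRegionIterArgs())
    if (llvm::isa<xilinx::air::AsyncTokenType>(arg.getType()))
      tokens.push_back(arg);
  return tokens;
}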
@@ -624,7 +624,7 @@ struct AnnotateFrontAndBackOpsInForPattern auto yield = for_op.getBody()->getTerminator(); SmallVector yielded_tokens; for (auto operand : yield->getOperands()) { - if (operand.getType().isa()) { + if (llvm::isa(operand.getType())) { yielded_tokens.push_back(operand); } } @@ -722,7 +722,7 @@ struct HoistMemallocInForPattern : public OpRewritePattern { // Find dealloc Operation *dealloc_op = nullptr; auto alloc_exec_memref = alloc_exec->getResults()[1]; - if (!alloc_exec_memref.getType().isa()) + if (!llvm::isa(alloc_exec_memref.getType())) alloc_op->emitOpError("the ssa value yielded from execute is not memref"); for (auto user : alloc_exec_memref.getUsers()) { if (isa(user)) { @@ -1239,7 +1239,7 @@ struct ConstructPingPongDependencyPattern Value getAsyncTokenFromValues(SmallVector vec) const { for (auto v : vec) { - if (v.getType().isa()) { + if (llvm::isa(v.getType())) { return v; } } @@ -1307,7 +1307,7 @@ struct ConstructPingPongDependencyPattern Value v, SmallVector &alloc_execs) const { if (auto exec = v.getDefiningOp()) { if (exec->hasAttr("unrolled_iteration") && exec->getNumResults() == 2 && - exec->getResult(1).getType().isa()) { + llvm::isa(exec->getResult(1).getType())) { alloc_execs.push_back(exec.getOperation()); for (auto dep : exec.getAsyncDependencies()) { pushToAllocExecsIfHoistedFromLoop(dep, alloc_execs); @@ -1359,7 +1359,7 @@ struct EnforceLoopCarriedMemrefDeallocPattern std::vector adjacent_events(Operation *event) const { SmallVector returned_tokens = {}; for (Value result : event->getResults()) { - if (result.getType().isa()) { + if (llvm::isa(result.getType())) { returned_tokens.push_back(result); } } @@ -1533,7 +1533,7 @@ struct CanonicalizeAffineApplyOnLoopInductionVar if (apply.getAffineMap().getNumInputs() != 1) return failure(); auto val = apply->getOperand(0); - auto ivArg = val.dyn_cast(); + auto ivArg = llvm::dyn_cast(val); if (!ivArg) return failure(); if (!ivArg.getOwner()) @@ -2051,7 +2051,7 @@ struct UnrollChannelByFactorPattern { // Update memref size (divide by factor) SmallVector new_sizes = op.getSizes(); if (new_sizes.empty()) { - auto memTy = op.getMemref().getType().template cast(); + auto memTy = llvm::cast(op.getMemref().getType()); for (auto d : getTensorShape(memTy)) { new_sizes.push_back( builder.create(par->getLoc(), d)); @@ -2117,14 +2117,12 @@ struct BroadcastDetection { void getDmaOpLoopDependency(func::FuncOp f) { f.walk([&](Operation *op) { if (auto dma_op = mlir::dyn_cast(op)) { - int src_memspace = dma_op.getSrcMemref() - .getType() - .cast() - .getMemorySpaceAsInt(); - int dst_memspace = dma_op.getDstMemref() - .getType() - .cast() - .getMemorySpaceAsInt(); + int src_memspace = + llvm::cast(dma_op.getSrcMemref().getType()) + .getMemorySpaceAsInt(); + int dst_memspace = + llvm::cast(dma_op.getDstMemref().getType()) + .getMemorySpaceAsInt(); bool isL1Memcpy = (src_memspace == (int)air::MemorySpace::L1) || (dst_memspace == (int)air::MemorySpace::L1); if (dma_op->getParentOfType() && isL1Memcpy) { @@ -2164,9 +2162,7 @@ struct BroadcastDetection { } // If not variant wrt herd, then check for fixed row-wise or col-wise // offset. - int src_memspace = dma_op.getSrcMemref() - .getType() - .cast() + int src_memspace = llvm::cast(dma_op.getSrcMemref().getType()) .getMemorySpaceAsInt(); auto externalOffsets = src_memspace == (int)air::MemorySpace::L1 ? 
dma_op.getDstOffsets() @@ -3115,7 +3111,7 @@ class AIRFuseChannels bool hitsMemorySpaceForAggMode(std::vector &puts, std::vector &gets) { for (auto put : puts) { - MemRefType ty = put.getMemref().getType().cast(); + MemRefType ty = llvm::cast(put.getMemref().getType()); if (llvm::any_of(targetMemorySpaces, [&](unsigned memSpace) { return memSpace == ty.getMemorySpaceAsInt(); })) { @@ -3123,7 +3119,7 @@ class AIRFuseChannels } } for (auto get : gets) { - MemRefType ty = get.getMemref().getType().cast(); + MemRefType ty = llvm::cast(get.getMemref().getType()); if (llvm::any_of(targetMemorySpaces, [&](unsigned memSpace) { return memSpace == ty.getMemorySpaceAsInt(); })) { @@ -3252,9 +3248,9 @@ class AIRFuseChannels if (a == b) return true; auto aHierOper = - getHierOperandFromHierBlockArgument(a.dyn_cast()); + getHierOperandFromHierBlockArgument(llvm::dyn_cast(a)); auto bHierOper = - getHierOperandFromHierBlockArgument(b.dyn_cast()); + getHierOperandFromHierBlockArgument(llvm::dyn_cast(b)); if (!(aHierOper && bHierOper)) return false; if (aHierOper == bHierOper) @@ -3527,7 +3523,7 @@ class AIRFuseChannels continue; if (auto execOp = dyn_cast(o)) if (execOp->getNumResults() == 2 && - execOp->getResult(1).getType().isa()) + llvm::isa(execOp->getResult(1).getType())) continue; eventCounter++; } @@ -3545,7 +3541,7 @@ class AIRFuseChannels OpBuilder builder(targetOp); SmallVector depList; for (auto operand : targetOp->getOperands()) { - if (operand.getType().isa()) + if (llvm::isa(operand.getType())) depList.push_back(operand); } for (auto res : targetOp->getResults()) { @@ -3792,9 +3788,9 @@ struct ShrinkMemrefSizesByAccessPattern } // Replace memref alloc op; - Type elemType = memref.getType().cast().getElementType(); + Type elemType = llvm::cast(memref.getType()).getElementType(); Attribute memorySpace = - memref.getType().cast().getMemorySpace(); + llvm::cast(memref.getType()).getMemorySpace(); auto newMemrefType = MemRefType::get(overall_access_bounds, elemType, nullptr, memorySpace); if (auto execOp = dyn_cast(alloc->getParentOp())) { @@ -3937,17 +3933,17 @@ struct ShrinkMemrefSizesByAccessPattern auto static_sizes = subViewOp.getStaticSizes(); auto static_strides = subViewOp.getStaticStrides(); // Get MemRefType after shrinkage. 
- Type elemType = - subViewOp.getSource().getType().cast().getElementType(); + Type elemType = llvm::cast(subViewOp.getSource().getType()) + .getElementType(); Attribute memorySpace = - subViewOp.getSource().getType().cast().getMemorySpace(); + llvm::cast(subViewOp.getSource().getType()) + .getMemorySpace(); auto shrunkMemrefType = MemRefType::get(overall_access_bounds, elemType, nullptr, memorySpace); MemRefType inferredSubViewOutputTy = - memref::SubViewOp::inferResultType( + llvm::cast(memref::SubViewOp::inferResultType( shrunkMemrefType, subViewOp.getStaticOffsets(), - subViewOp.getStaticSizes(), subViewOp.getStaticStrides()) - .cast(); + subViewOp.getStaticSizes(), subViewOp.getStaticStrides())); for (unsigned i = 0; i < static_sizes.size(); i++) { if (static_sizes[i] < 0) { if (*getConstantIntValue(*subview_sizes++) != diff --git a/mlir/lib/Transform/AIRLinalgCodegen.cpp b/mlir/lib/Transform/AIRLinalgCodegen.cpp index c67807532..9f1017a26 100644 --- a/mlir/lib/Transform/AIRLinalgCodegen.cpp +++ b/mlir/lib/Transform/AIRLinalgCodegen.cpp @@ -167,7 +167,7 @@ struct MemrefsPattern : public OpRewritePattern { // LogicalResult matchAndRewrite(memref::DimOp op, // PatternRewriter &rewriter) const override { -// auto operTy = op.memrefOrTensor().getType().dyn_cast(); +// auto operTy = llvm::dyn_cast(op.memrefOrTensor().getType()); // if (!operTy.hasStaticShape()) // return failure(); @@ -962,7 +962,7 @@ FailureOr static pipelineReduceLinalgOp( b.setInsertionPointToStart(stageBlock); if (i) { - auto ty = tiledOperands[resultIdx].getType().cast(); + auto ty = llvm::cast(tiledOperands[resultIdx].getType()); auto alloc = b.create( loc, MemRefType::get(ty.getShape(), ty.getElementType(), AffineMap(), (int)air::MemorySpace::L1)); @@ -1190,7 +1190,7 @@ class AIRLinalgCodegen affine::makeComposedFoldedMultiResultAffineApply( b, loc, shapeSizesToLoopsMap, allShapeSizes); for (auto size : shapeSizes) { - if (auto v = size.dyn_cast()) { + if (auto v = llvm::dyn_cast(size)) { auto c = dyn_cast(v.getDefiningOp()); if (!c) { LLVM_DEBUG(llvm::outs() << "Found non-constant dim!\n"); @@ -1198,8 +1198,8 @@ class AIRLinalgCodegen } tripCounts.push_back(c.value()); } else { - auto a = size.dyn_cast(); - auto c = a.dyn_cast(); + auto a = llvm::dyn_cast(size); + auto c = llvm::dyn_cast(a); if (!c) { LLVM_DEBUG(llvm::outs() << "unhandled addr!\n"); return {}; @@ -1847,7 +1847,7 @@ transform::LinalgTileOp::apply(TransformRewriter &rewriter, for (Operation *op : dynamicSizeProducers.back()) { if (op->getNumResults() == 1 && - op->getResult(0).getType().isa()) + llvm::isa(op->getResult(0).getType())) continue; DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected sizes to be produced by ops " @@ -1880,9 +1880,9 @@ transform::LinalgTileOp::apply(TransformRewriter &rewriter, sizes.reserve(tileSizes.size()); unsigned dynamicIdx = 0; for (OpFoldResult ofr : getMixedSizes()) { - if (auto attr = ofr.dyn_cast()) { + if (auto attr = llvm::dyn_cast(ofr)) { sizes.push_back(b.create( - getLoc(), attr.cast().getInt())); + getLoc(), llvm::cast(attr).getInt())); } else { sizes.push_back( dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); @@ -1910,9 +1910,10 @@ transform::LinalgTileOp::apply(TransformRewriter &rewriter, loops[en2.index()].push_back(en2.value()); } - transformResults.set(getTiledLinalgOp().cast(), tiled); + transformResults.set(llvm::cast(getTiledLinalgOp()), tiled); for (const auto &en : llvm::enumerate(loops)) - transformResults.set(getLoops()[en.index()].cast(), en.value()); + 
transformResults.set(llvm::cast(getLoops()[en.index()]), + en.value()); return DiagnosedSilenceableFailure::success(); } @@ -2108,7 +2109,7 @@ transform::LinalgPromoteOp::apply(transform::TransformRewriter &rewriter, if (!transformed.size()) return emitDefaultDefiniteFailure(payloadOps[0]); - results.set(getResult().cast(), transformed.getArrayRef()); + results.set(llvm::cast(getResult()), transformed.getArrayRef()); return DiagnosedSilenceableFailure::success(); } @@ -2233,7 +2234,7 @@ DiagnosedSilenceableFailure transform::FuseIntoContainingMemrefOp::apply( llvm::to_vector(state.getPayloadOps(getProducerOp())); // If nothing to fuse, propagate success. if (producerOps.empty()) { - results.set(getFusedOp().cast(), + results.set(llvm::cast(getFusedOp()), SmallVector{}); return DiagnosedSilenceableFailure::success(); } @@ -2270,7 +2271,7 @@ DiagnosedSilenceableFailure transform::FuseIntoContainingMemrefOp::apply( return containingOp->isAncestor(op); }); if (numUsesInContainingOp == 0) { - results.set(getFusedOp().cast(), ArrayRef()); + results.set(llvm::cast(getFusedOp()), ArrayRef()); Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark); diag << "producer_op does not have uses in the container"; return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); @@ -2288,11 +2289,11 @@ DiagnosedSilenceableFailure transform::FuseIntoContainingMemrefOp::apply( fusedOps.push_back(tiled); rewriter.eraseOp(producerOp); - results.set(getFusedOp().cast(), fusedOps); + results.set(llvm::cast(getFusedOp()), fusedOps); return DiagnosedSilenceableFailure::success(); } - results.set(getFusedOp().cast(), ArrayRef()); + results.set(llvm::cast(getFusedOp()), ArrayRef()); return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } diff --git a/mlir/lib/Transform/AIRMiscPasses.cpp b/mlir/lib/Transform/AIRMiscPasses.cpp index be6895944..541cbe64a 100644 --- a/mlir/lib/Transform/AIRMiscPasses.cpp +++ b/mlir/lib/Transform/AIRMiscPasses.cpp @@ -162,8 +162,8 @@ void AIRPromoteUniformL1Dma::runOnOperation() { if (!uniform) continue; - auto src_type = memcpyOp.getSrc().getType().cast(); - auto dst_type = memcpyOp.getDst().getType().cast(); + auto src_type = llvm::cast(memcpyOp.getSrc().getType()); + auto dst_type = llvm::cast(memcpyOp.getDst().getType()); auto src_space = src_type.getMemorySpaceAsInt(); auto dst_space = dst_type.getMemorySpaceAsInt(); @@ -199,8 +199,8 @@ void AIRPromoteUniformL1Dma::runOnOperation() { launch_operands.push_back(alloc.getResult()); launch->setOperands(launch_operands); launch.getBody().front().addArgument(alloc.getType(), loc); - auto sizeAttr = launch->getAttr("operand_segment_sizes") - .cast<::mlir::DenseIntElementsAttr>(); + auto sizeAttr = llvm::cast<::mlir::DenseIntElementsAttr>( + launch->getAttr("operand_segment_sizes")); const uint32_t *it = &*sizeAttr.value_begin(); auto newAttr = DenseIntElementsAttr::get(sizeAttr.getType(), {it[0], it[1], it[2], it[3] + 1}); @@ -983,8 +983,8 @@ void AIRCollapseHerdPass::runOnOperation() { // Determine the current induction value's current loop iteration Value iv_1 = insideBuilder.create(loc, h.getIds()[1], old_upper_b_v); - h.getIds()[1].cast().replaceAllUsesExcept(iv_1, - iv_1.getDefiningOp()); + llvm::cast(h.getIds()[1]) + .replaceAllUsesExcept(iv_1, iv_1.getDefiningOp()); // Remove the effect of the current induction value to prepare for // the next value. 
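Attributes migrate identically, as the AIRMiscPasses hunk above shows for operand_segment_sizes. A minimal sketch, under the assumption that the attribute is a DenseIntElementsAttr of four uint32 segment sizes; bumpLastSegment is a hypothetical helper:

#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"

// Rebuild operand_segment_sizes with the last segment incremented, using
// llvm::cast<T>(attr) in place of the deprecated attr.cast<T>().
static mlir::DenseIntElementsAttr bumpLastSegment(mlir::Operation *op) {
  auto sizeAttr = llvm::cast<mlir::DenseIntElementsAttr>(
      op->getAttr("operand_segment_sizes"));
  llvm::SmallVector<uint32_t> segments(sizeAttr.value_begin<uint32_t>(),
                                       sizeAttr.value_end<uint32_t>());
  segments.back() += 1; // account for one newly appended operand
  return mlir::DenseIntElementsAttr::get(sizeAttr.getType(), segments);
}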
@@ -1176,7 +1176,7 @@ void AIRSplitL2MemrefForBufferConstraintPass::partitionMemref( SmallVector &puts, SmallVector &gets, int dim, std::string splitType = "") { auto memref = puts.front().getMemref(); - MemRefType ty = memref.getType().cast(); + MemRefType ty = llvm::cast(memref.getType()); auto allocOp = memref.getDefiningOp(); auto loc = allocOp->getLoc(); Operation *deallocOp = nullptr; @@ -1358,9 +1358,7 @@ AIRSplitL2MemrefForBufferConstraintPass::getTargetMemrefAllocs( SmallVector allocOps; func.walk([&](memref::AllocOp allocOp) { if (allocOp->getParentOfType() && - allocOp.getMemref() - .getType() - .cast() + llvm::cast(allocOp.getMemref().getType()) .getMemorySpaceAsInt() == (int)air::MemorySpace::L2) { allocOps.push_back(allocOp); } @@ -1469,9 +1467,7 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() { SmallVector allocOps; func.walk([&](memref::AllocOp allocOp) { if (allocOp->getParentOfType() && - allocOp.getMemref() - .getType() - .cast() + llvm::cast(allocOp.getMemref().getType()) .getMemorySpaceAsInt() == (int)air::MemorySpace::L2) { allocOps.push_back(allocOp); } diff --git a/mlir/lib/Transform/ReturnEliminationPass.cpp b/mlir/lib/Transform/ReturnEliminationPass.cpp index 8d28695f4..94501034b 100644 --- a/mlir/lib/Transform/ReturnEliminationPass.cpp +++ b/mlir/lib/Transform/ReturnEliminationPass.cpp @@ -63,11 +63,11 @@ class ReturnEliminationPass callOp.arg_operand_end()}; for (auto v : callOp.getResults()) { - if (!v.getType().isa()) + if (!llvm::isa(v.getType())) llvm_unreachable("function returns non-memref"); if (!valueMap.count(v)) { - valueMap[v] = builder->create(op->getLoc(), - v.getType().cast()); + valueMap[v] = builder->create( + op->getLoc(), llvm::cast(v.getType())); } v.replaceAllUsesWith(valueMap[v]); newCallArgs.push_back(valueMap[v]); @@ -94,9 +94,9 @@ class ReturnEliminationPass } for (Value v : op->getOperands()) { - if (!v.getType().isa()) + if (!llvm::isa(v.getType())) continue; - if (v.isa()) + if (llvm::isa(v)) continue; if (v.getDefiningOp()) runOn(v.getDefiningOp()); @@ -150,7 +150,7 @@ class ReturnEliminationPass valueMap[v] = BB.addArgument(v.getType(), retOp->getLoc()); for (Value v : operands) { - if (!v.getType().isa()) + if (!llvm::isa(v.getType())) llvm_unreachable("graph function returns non-memref"); if (v.getDefiningOp()) runOn(v.getDefiningOp()); @@ -159,7 +159,7 @@ class ReturnEliminationPass for (auto oi=BB.rbegin(),oe=BB.rend(); oi!=oe; ++oi) { Operation *o = &*oi; for (Value v : o->getResults()) { - if (v.getType().isa()) { + if (llvm::isa(v.getType())) { runOn(o); break; } diff --git a/mlir/lib/Util/CostModel.cpp b/mlir/lib/Util/CostModel.cpp index 58cc3ca4a..c940375d1 100644 --- a/mlir/lib/Util/CostModel.cpp +++ b/mlir/lib/Util/CostModel.cpp @@ -43,10 +43,9 @@ static uint64_t getTensorVolume(const ShapedType ty) { } static uint64_t getTensorVolume(const Type ty) { - if (auto t = ty.dyn_cast()) { + if (auto t = llvm::dyn_cast(ty)) { return getTensorVolume(t); - } - else { + } else { return 1; } } @@ -70,7 +69,7 @@ CostModel::getLinalgOpCounts(OpCountMap &map, linalg::LinalgOp op) { int64_t writes = 0; uint64_t footprint = 0; for (auto size : shapeSizes) { - if (auto v = size.dyn_cast()) { + if (auto v = llvm::dyn_cast(size)) { auto c = dyn_cast(v.getDefiningOp()); if (!c) { LLVM_DEBUG(llvm::outs() << "Found non-constant dim!\n"); @@ -78,8 +77,8 @@ CostModel::getLinalgOpCounts(OpCountMap &map, linalg::LinalgOp op) { } iters *= c.value(); } else { - auto a = size.dyn_cast(); - auto c = a.dyn_cast(); + auto a = 
llvm::dyn_cast(size); + auto c = llvm::dyn_cast(a); if (!c) { LLVM_DEBUG(llvm::outs() << "unhandled addr!\n"); return; diff --git a/mlir/lib/Util/Dependency.cpp b/mlir/lib/Util/Dependency.cpp index 3616ae4dd..de5a33231 100644 --- a/mlir/lib/Util/Dependency.cpp +++ b/mlir/lib/Util/Dependency.cpp @@ -79,8 +79,8 @@ void traceDependentInductionVar(air::DmaMemcpyNdOp dmaNd_op, // Recursively trace dependency to loop induction vars for (auto operand : candidate_scalar_operands) { - if (operand && - operand.getType().isa()) { // Only tracing scalar operands + if (operand && llvm::isa( + operand.getType())) { // Only tracing scalar operands if (operand.getDefiningOp() && mlir::dyn_cast(operand.getDefiningOp())) { auto ancestor_async_op = @@ -151,8 +151,8 @@ void traceDependentInductionVar(air::AsyncOpInterface async_op, // Recursively trace dependency to loop induction vars for (auto operand : op->getOperands()) { - if (operand && - operand.getType().isa()) { // Only tracing scalar operands + if (operand && llvm::isa( + operand.getType())) { // Only tracing scalar operands if (operand.getDefiningOp() && mlir::dyn_cast(operand.getDefiningOp())) { auto ancestor_async_op = @@ -204,8 +204,8 @@ void traceDependentHerdId(Operation *async_op, // Recursively trace dependency to loop induction vars for (auto operand : op->getOperands()) { - if (operand && - operand.getType().isa()) { // Only tracing scalar operands + if (operand && llvm::isa( + operand.getType())) { // Only tracing scalar operands if (operand.getDefiningOp() && mlir::dyn_cast(operand.getDefiningOp())) { op_history.push_back(operand.getDefiningOp()); @@ -261,9 +261,8 @@ traceDependentHerdId(air::DmaMemcpyNdOp dmaNd_op) { // Recursively trace dependency to loop induction vars for (auto &elem : loop_dep_history) { if (std::get<0>(elem) && - std::get<0>(elem) - .getType() - .isa()) { // Only tracing scalar operands + llvm::isa( + std::get<0>(elem).getType())) { // Only tracing scalar operands if (std::get<0>(elem).getDefiningOp() && mlir::dyn_cast( std::get<0>(elem).getDefiningOp())) { @@ -318,7 +317,7 @@ void eraseAsyncDependencyFromAsyncOp(xilinx::air::AsyncOpInterface op, Value token) { if (!token) return; - if (!token.getType().isa()) + if (!llvm::isa(token.getType())) return; auto dependency_list = op.getAsyncDependencies(); if (!dependency_list.size()) @@ -403,7 +402,7 @@ Value getLoopCarriedTokenFromScfOp(scf::ParallelOp op) { return nullptr; } auto token = op.getInitVals()[0]; - if (!token.getType().isa()) { + if (!llvm::isa(token.getType())) { op->emitOpError("init_val is not an async token"); return nullptr; } @@ -417,7 +416,7 @@ Value getLoopCarriedTokenFromScfOp(scf::ForOp op, return nullptr; } auto token = op.getInitArgs()[0]; - if (!token.getType().isa()) { + if (!llvm::isa(token.getType())) { op->emitOpError("iter operand is not an async token"); return nullptr; } @@ -428,7 +427,7 @@ Value getLoopCarriedTokenFromScfOp(scf::ForOp op, return nullptr; } auto token = op.getRegionIterArgs()[0]; - if (!token.getType().isa()) { + if (!llvm::isa(token.getType())) { op->emitOpError("iter operand is not an async token"); return nullptr; } @@ -539,7 +538,7 @@ Value getAsyncTokenFromOp(Operation *op) { // Add async dependency to op if unique void addAsyncDependencyIfNewImpl(air::AsyncOpInterface op, Value token) { - if (!token.getType().isa()) { + if (!llvm::isa(token.getType())) { op->emitOpError("value is not an async token"); return; } @@ -630,7 +629,7 @@ void addAsyncDependencyIfNew(Operation *op, Value token) { bool isAsyncOp(Operation 
*op) { for (auto result : op->getResults()) { - if (result.getType().isa()) { + if (llvm::isa(result.getType())) { return true; } } @@ -1282,7 +1281,7 @@ Graph::VertexId dependencyCanonicalizer::addVertexFromExecuteOp( // Annotate memref's memory space std::string memorySpaceStr = getMemorySpaceAsString(alloc_child_op.getMemref()); - auto ty = alloc_child_op.getMemref().getType().cast(); + auto ty = llvm::cast(alloc_child_op.getMemref().getType()); detailed_description += "(" + memorySpaceStr + ", " + std::to_string(getTensorVolume(ty)) + ", " + getElementTypeAsString(ty) + ")"; @@ -1294,7 +1293,7 @@ Graph::VertexId dependencyCanonicalizer::addVertexFromExecuteOp( // Annotate memref's memory space std::string memorySpaceStr = getMemorySpaceAsString(dealloc_child_op.getMemref()); - auto ty = dealloc_child_op.getMemref().getType().cast(); + auto ty = llvm::cast(dealloc_child_op.getMemref().getType()); detailed_description += "(" + memorySpaceStr + ", " + std::to_string(getTensorVolume(ty)) + ", " + getElementTypeAsString(ty) + ")"; @@ -1993,7 +1992,7 @@ void dependencyCanonicalizer::redoDepTraceIfDepOnHier(func::FuncOp func) { void dependencyTracer::pushDepsAtCurrentScope(mlir::Value operand, air::AsyncOpInterface op, char rw, partialMemref *tile) { - if (!operand.getType().isa()) + if (!llvm::isa(operand.getType())) op->emitOpError("operand being traced is not a memref"); for (auto &u : operand.getUses()) { // If used in MemcpyInterface Op @@ -2001,7 +2000,7 @@ void dependencyTracer::pushDepsAtCurrentScope(mlir::Value operand, partialMemref memcpy_src, memcpy_dst; if (memcpy.getSrcMemref()) { unsigned numDimsSrc = - memcpy.getSrcMemref().getType().cast().getRank(); + llvm::cast(memcpy.getSrcMemref().getType()).getRank(); SmallVector src_indices; if (memcpy.getSrcOffsets().size()) { numDimsSrc = memcpy.getSrcOffsets().size(); @@ -2018,7 +2017,7 @@ void dependencyTracer::pushDepsAtCurrentScope(mlir::Value operand, } if (memcpy.getDstMemref()) { unsigned numDimsDst = - memcpy.getDstMemref().getType().cast().getRank(); + llvm::cast(memcpy.getDstMemref().getType()).getRank(); SmallVector dst_indices; if (memcpy.getDstOffsets().size()) { numDimsDst = memcpy.getDstOffsets().size(); @@ -2147,22 +2146,24 @@ void dependencyTracer::getPartialMemrefFromOp( if (auto sink_op_linalgop = dyn_cast(sink_op)) { for (auto linalg_ins : sink_op_linalgop.getDpsInputOperands()) { auto ins_value = linalg_ins->get(); - if (ins_value.getType().isa()) { - unsigned memRefRank = ins_value.getType().cast().getRank(); + if (llvm::isa(ins_value.getType())) { + unsigned memRefRank = + llvm::cast(ins_value.getType()).getRank(); partialMemref tile = createPartialMemref(ins_value, memRefRank); sink_op_memref_reads.push_back(tile); - } else if (ins_value.getType().isa()) { + } else if (llvm::isa(ins_value.getType())) { sink_op_scalar_ins.push_back(ins_value); } } for (auto outs_value : sink_op_linalgop.getDpsInits()) { - if (outs_value.getType().isa()) { - unsigned memRefRank = outs_value.getType().cast().getRank(); + if (llvm::isa(outs_value.getType())) { + unsigned memRefRank = + llvm::cast(outs_value.getType()).getRank(); partialMemref tile = createPartialMemref(outs_value, memRefRank); sink_op_memref_reads.push_back( tile); // linalg op both reads and writes the output memref sink_op_memref_writes.push_back(tile); - } else if (outs_value.getType().isa()) { + } else if (llvm::isa(outs_value.getType())) { sink_op_scalar_ins.push_back(outs_value); // linalg op both reads and // writes the output memref 
sink_op_scalar_outs.push_back(outs_value); @@ -2170,12 +2171,12 @@ void dependencyTracer::getPartialMemrefFromOp( } if (sink_op_linalgop->getNumResults()) { for (auto linalg_results : sink_op_linalgop->getResults()) { - if (linalg_results.getType().isa()) { + if (llvm::isa(linalg_results.getType())) { unsigned memRefRank = - linalg_results.getType().cast().getRank(); + llvm::cast(linalg_results.getType()).getRank(); partialMemref tile = createPartialMemref(linalg_results, memRefRank); sink_op_memref_writes.push_back(tile); - } else if (linalg_results.getType().isa()) { + } else if (llvm::isa(linalg_results.getType())) { sink_op_scalar_outs.push_back(linalg_results); } } @@ -2185,7 +2186,8 @@ void dependencyTracer::getPartialMemrefFromOp( // If the sink op is memref::dealloc else if (auto sink_op_memdealloc = dyn_cast(sink_op)) { unsigned memRefRank = - sink_op_memdealloc.getMemref().getType().cast().getRank(); + llvm::cast(sink_op_memdealloc.getMemref().getType()) + .getRank(); partialMemref tile = createPartialMemref(sink_op_memdealloc.getMemref(), memRefRank); sink_op_memref_reads.push_back(tile); @@ -2196,12 +2198,14 @@ void dependencyTracer::getPartialMemrefFromOp( // If the sink op is memref::copy else if (auto sink_op_memref_copy = dyn_cast(sink_op)) { unsigned memRefRankSrc = - sink_op_memref_copy.getSource().getType().cast().getRank(); + llvm::cast(sink_op_memref_copy.getSource().getType()) + .getRank(); partialMemref tileSrc = createPartialMemref(sink_op_memref_copy.getSource(), memRefRankSrc); sink_op_memref_reads.push_back(tileSrc); unsigned memRefRankDst = - sink_op_memref_copy.getTarget().getType().cast().getRank(); + llvm::cast(sink_op_memref_copy.getTarget().getType()) + .getRank(); partialMemref tileDst = createPartialMemref(sink_op_memref_copy.getTarget(), memRefRankDst); sink_op_memref_reads.push_back(tileDst); @@ -2214,7 +2218,8 @@ void dependencyTracer::getPartialMemrefFromOp( if (sink_op_memcpy.getSrcMemref()) { SmallVector src_indices; unsigned numDimsSrc = - sink_op_memcpy.getSrcMemref().getType().cast().getRank(); + llvm::cast(sink_op_memcpy.getSrcMemref().getType()) + .getRank(); for (unsigned i = 0; i < sink_op_memcpy.getSrcOffsets().size(); i++) sink_op_scalar_ins.push_back(sink_op_memcpy.getSrcOffsets()[i]); for (unsigned i = 0; i < sink_op_memcpy.getSrcSizes().size(); i++) @@ -2238,7 +2243,8 @@ void dependencyTracer::getPartialMemrefFromOp( if (sink_op_memcpy.getDstMemref()) { SmallVector dst_indices; unsigned numDimsDst = - sink_op_memcpy.getDstMemref().getType().cast().getRank(); + llvm::cast(sink_op_memcpy.getDstMemref().getType()) + .getRank(); // air.dmamemcpynd op's scalar operands for (unsigned i = 0; i < sink_op_memcpy.getDstOffsets().size(); i++) sink_op_scalar_outs.push_back(sink_op_memcpy.getDstOffsets()[i]); @@ -2287,13 +2293,14 @@ void dependencyTracer::getPartialMemrefFromOp( // If the sink op is an unknown op else { for (auto sink_op_op : sink_op->getOperands()) { - if (sink_op_op.getType().isa()) { - unsigned memRefRank = sink_op_op.getType().cast().getRank(); + if (llvm::isa(sink_op_op.getType())) { + unsigned memRefRank = + llvm::cast(sink_op_op.getType()).getRank(); partialMemref tile = createPartialMemref(sink_op_op, memRefRank); sink_op_memref_reads.push_back( tile); // Assuming all operands are both read and written to sink_op_memref_writes.push_back(tile); - } else if (sink_op_op.getType().isa()) { + } else if (llvm::isa(sink_op_op.getType())) { sink_op_scalar_ins.push_back(sink_op_op); // Assuming all operands are // both read and 
written to sink_op_scalar_outs.push_back(sink_op_op); @@ -2301,12 +2308,12 @@ void dependencyTracer::getPartialMemrefFromOp( } if (sink_op->getNumResults()) { for (auto sink_op_results : sink_op->getResults()) { - if (sink_op_results.getType().isa()) { + if (llvm::isa(sink_op_results.getType())) { unsigned memRefRank = - sink_op_results.getType().cast().getRank(); + llvm::cast(sink_op_results.getType()).getRank(); partialMemref tile = createPartialMemref(sink_op_results, memRefRank); sink_op_memref_writes.push_back(tile); - } else if (sink_op_results.getType().isa()) { + } else if (llvm::isa(sink_op_results.getType())) { sink_op_scalar_outs.push_back(sink_op_results); } } @@ -2326,7 +2333,7 @@ void dependencyTracer::addDependencyBetweenOps(Operation *source, return; } } - for (auto parent = source->getParentOp(); !isa(parent); + for (auto parent = source->getParentOp(); !llvm::isa(parent); parent = parent->getParentOp()) { if (parent->getBlock() == sink->getBlock() && parent->isBeforeInBlock(sink)) { @@ -2354,7 +2361,7 @@ bool dependencyTracer::areEqualIndexPartialMemrefs(partialMemref *tile_0, } char dependencyTracer::checkOperandReadOrWrite(mlir::Value operand) { - if (!operand.getType().isa()) + if (!llvm::isa(operand.getType())) operand.getDefiningOp()->emitOpError( "operand being traced is not a memref"); bool foundWriteAccess = false; diff --git a/mlir/lib/Util/Runner.cpp b/mlir/lib/Util/Runner.cpp index 4180c08ef..ed8d0a295 100644 --- a/mlir/lib/Util/Runner.cpp +++ b/mlir/lib/Util/Runner.cpp @@ -179,8 +179,8 @@ class AIRRunner::AIRRunner_impl { c.op->emitOpError("has mismatching event type").attachNote() << "Has 'dma' as event type, but op isn't of type " "air::DmaMemcpyNdOp"; - MemRefType srcTy = Op.getSrcMemref().getType().cast(); - MemRefType dstTy = Op.getDstMemref().getType().cast(); + MemRefType srcTy = llvm::cast(Op.getSrcMemref().getType()); + MemRefType dstTy = llvm::cast(Op.getDstMemref().getType()); auto srcSpace = srcTy.getMemorySpaceAsInt(); auto dstSpace = dstTy.getMemorySpaceAsInt(); // if there is a size mismatch, it's because we're moving a tile of the @@ -196,12 +196,12 @@ class AIRRunner::AIRRunner_impl { c.op->emitOpError("has mismatching event type").attachNote() << "Has 'channel' as event type, but op isn't of type " "air::ChannelGetOp"; - MemRefType dstTy = getOp.getDst().getType().cast(); + MemRefType dstTy = llvm::cast(getOp.getDst().getType()); std::vector putOps = air::getTheOtherChannelOpThroughSymbol(getOp); if (!putOps.size()) getOp->emitOpError("found no put op for air::ChannelGetOp"); - MemRefType srcTy = putOps[0].getSrc().getType().cast(); + MemRefType srcTy = llvm::cast(putOps[0].getSrc().getType()); auto srcSpace = srcTy.getMemorySpaceAsInt(); auto dstSpace = dstTy.getMemorySpaceAsInt(); auto srcVolumn = getTransferVolumn(putOps[0]); @@ -622,7 +622,7 @@ class AIRRunner::AIRRunner_impl { } uint64_t getTransferVolumn(air::ChannelInterface op) { - MemRefType memTy = op.getMemref().getType().cast(); + MemRefType memTy = llvm::cast(op.getMemref().getType()); if (op.getSizes().empty()) return getTensorVolume(memTy); else diff --git a/mlir/lib/Util/Runner/RunnerNode.cpp b/mlir/lib/Util/Runner/RunnerNode.cpp index c09c1ab68..b0d310c51 100644 --- a/mlir/lib/Util/Runner/RunnerNode.cpp +++ b/mlir/lib/Util/Runner/RunnerNode.cpp @@ -655,7 +655,7 @@ class runnerNode { std::vector resource_pool; double memory_pool = this->getMemoriesPool(resource_pool); // Get memory allocation size - MemRefType ty = Op.getMemref().getType().cast(); + MemRefType ty = 
llvm::cast(Op.getMemref().getType()); double memory_allocated = this->getMemoryCostInBytes(ty, Op.getOperation()); if (memory_allocated <= memory_pool) { return true; @@ -667,7 +667,7 @@ class runnerNode { std::vector resource_pool; double memory_pool = this->getMemoriesPool(resource_pool, false); // Get memory allocation size - MemRefType ty = Op.getMemref().getType().cast(); + MemRefType ty = llvm::cast(Op.getMemref().getType()); double memory_deallocated = this->getMemoryCostInBytes(ty, Op.getOperation()); if (memory_deallocated <= memory_pool) { @@ -781,7 +781,7 @@ class runnerNode { std::vector resource_pool; this->getMemoriesPool(resource_pool); // Get memory size in bytes - MemRefType ty = Op.getMemref().getType().cast(); + MemRefType ty = llvm::cast(Op.getMemref().getType()); double memory_allocated = this->getMemoryCostInBytes(ty, Op.getOperation()); // Reserve resource this->allocateRunnerNodeToAllocateMemory(resource_pool, reserved_resources, @@ -793,7 +793,7 @@ class runnerNode { std::vector resource_pool; this->getMemoriesPool(resource_pool, false); // Get memory size in bytes - MemRefType ty = Op.getMemref().getType().cast(); + MemRefType ty = llvm::cast(Op.getMemref().getType()); double memory_deallocated = this->getMemoryCostInBytes(ty, Op.getOperation()); // Reserve resource diff --git a/mlir/lib/Util/Util.cpp b/mlir/lib/Util/Util.cpp index 31e4f4d8f..20eba3aba 100644 --- a/mlir/lib/Util/Util.cpp +++ b/mlir/lib/Util/Util.cpp @@ -37,7 +37,7 @@ const StringLiteral air::LinalgTransforms::kLinalgTransformMarker = static std::string getMangledType(const Type ty) { std::stringstream ret; - if (const MemRefType mrt = ty.dyn_cast()) { + if (const MemRefType mrt = llvm::dyn_cast(ty)) { ret << "M"; ret << mrt.getMemorySpaceAsInt(); if (mrt.hasStaticShape()) { @@ -49,13 +49,13 @@ static std::string getMangledType(const Type ty) { } const Type elem = mrt.getElementType(); ret << getMangledType(elem); - } else if (FloatType ft = ty.dyn_cast()) { + } else if (FloatType ft = llvm::dyn_cast(ty)) { ret << "F" << ft.getWidth(); - } else if (const IntegerType it = ty.dyn_cast()) { + } else if (const IntegerType it = llvm::dyn_cast(ty)) { ret << "I" << it.getWidth(); - } else if (const IndexType it = ty.dyn_cast()) { + } else if (const IndexType it = llvm::dyn_cast(ty)) { ret << "I64"; - } else if (ty.dyn_cast()) { + } else if (llvm::dyn_cast(ty)) { ret << "E"; } else { Type t = ty; @@ -145,7 +145,7 @@ uint64_t air::getTensorVolume(const ShapedType ty) { } uint64_t air::getTensorVolume(const Type ty) { - if (auto t = ty.dyn_cast()) { + if (auto t = llvm::dyn_cast(ty)) { return getTensorVolume(t); } else { return 1; @@ -162,7 +162,7 @@ SmallVector air::getTensorShape(const ShapedType ty) { } SmallVector air::getTensorShape(const Type ty) { - if (auto t = ty.dyn_cast()) { + if (auto t = llvm::dyn_cast(ty)) { return getTensorShape(t); } else { return SmallVector(1); @@ -170,7 +170,7 @@ SmallVector air::getTensorShape(const Type ty) { } std::string air::getElementTypeAsString(const mlir::Type ty) { - if (auto st = ty.dyn_cast()) { + if (auto st = llvm::dyn_cast(ty)) { return to_string(st.getElementType()); } else { return to_string(ty); @@ -179,7 +179,7 @@ std::string air::getElementTypeAsString(const mlir::Type ty) { // An incomplete lookup table of common data types uint64_t air::getElementSizeInBytes(const mlir::Type ty) { - if (auto memrefTy = ty.cast()) { + if (auto memrefTy = llvm::cast(ty)) { return memrefTy.getElementTypeBitWidth() / 8; } auto typeAsString = getElementTypeAsString(ty); @@ 
-201,7 +201,7 @@ uint64_t air::getElementSizeInBytes(const mlir::Type ty) { // Get the parent scf.for op of an iter_arg scf::ForOp air::getForRegionIterArgsOwner(Value val) { - auto ivArg = val.dyn_cast(); + auto ivArg = llvm::dyn_cast(val); if (!ivArg) return scf::ForOp(); if (!ivArg.getOwner()) { @@ -225,7 +225,7 @@ scf::ParallelOp air::getParallelRegionInitValsOwner(Operation *op, Value val) { // Get the parent air.launch_herd op of a tile id air::HerdOp air::getHerdArgOwner(Value val) { - auto ivArg = val.dyn_cast(); + auto ivArg = llvm::dyn_cast(val); if (!ivArg) return air::HerdOp(); if (!ivArg.getOwner()) { @@ -238,7 +238,7 @@ air::HerdOp air::getHerdArgOwner(Value val) { // Get the parent air.hierarchy op of a tile id air::HierarchyInterface air::getHierarchyArgOwner(Value val) { - auto ivArg = val.dyn_cast(); + auto ivArg = llvm::dyn_cast(val); if (!ivArg) return air::HierarchyInterface(); if (!ivArg.getOwner()) { @@ -365,12 +365,12 @@ std::string air::createChannelName(Operation *scope) { // Return memory space as string std::string air::getMemorySpaceAsString(Value memref) { - if (!memref.getType().isa()) { + if (!llvm::isa(memref.getType())) { memref.getDefiningOp()->emitOpError("value returned is not a memref"); return ""; } auto memory_space_as_int = - memref.getType().dyn_cast().getMemorySpaceAsInt(); + llvm::dyn_cast(memref.getType()).getMemorySpaceAsInt(); std::string memorySpaceStr = ""; if (memory_space_as_int == (int)air::MemorySpace::L1) { memorySpaceStr = "L1"; @@ -1100,7 +1100,7 @@ SmallVector air::getDataAccessShapeFromMemcpyOp( return overall_access_bounds; } -void updateAccessPatternByScfForNest( +static void updateAccessPatternByScfForNest( std::tuple, SmallVector, SmallVector> &pattern, SmallVector indices, OpBuilder builder) { @@ -1155,7 +1155,7 @@ air::writeAccessPattern(memref::SubViewOp subview) { auto static_strides = subview.getStaticStrides(); // Get strided layout from subview op's output MemRefType if (auto strided = llvm::dyn_cast( - subview.getResult().getType().cast().getLayout())) + llvm::cast(subview.getResult().getType()).getLayout())) static_strides = strided.getStrides(); auto loc = subview.getLoc(); @@ -1190,8 +1190,8 @@ air::writeAccessPattern(mlir::vector::TransferReadOp readOp) { OpBuilder builder(readOp); std::tuple, SmallVector, SmallVector> pattern; - auto vectorTy = readOp.getVector().getType().cast(); - auto memrefTy = readOp.getSource().getType().cast(); + auto vectorTy = llvm::cast(readOp.getVector().getType()); + auto memrefTy = llvm::cast(readOp.getSource().getType()); assert(vectorTy && "Not a vector"); assert(memrefTy && "Not a memref"); // Initialize wraps and strides based on the unshrunk memref shape. @@ -1211,8 +1211,8 @@ air::writeAccessPattern(mlir::vector::TransferWriteOp writeOp) { OpBuilder builder(writeOp); std::tuple, SmallVector, SmallVector> pattern; - auto memrefTy = writeOp.getSource().getType().cast(); - auto vectorTy = writeOp.getVector().getType().cast(); + auto memrefTy = llvm::cast(writeOp.getSource().getType()); + auto vectorTy = llvm::cast(writeOp.getVector().getType()); assert(memrefTy && "Not a memref"); assert(vectorTy && "Not a vector"); // Initialize wraps and strides based on the unshrunk memref shape.
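The same free functions cover the remaining cast sites in this patch: mlir::Value down to BlockArgument (Util.cpp) and OpFoldResult to its Attribute or Value alternative (AIRLinalgCodegen.cpp), since llvm::dyn_cast also works on the underlying PointerUnion. A hedged sketch of the OpFoldResult case; staticSizeOrDynamic is an illustrative name:

#include "llvm/Support/Casting.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"

// OpFoldResult is a PointerUnion<Attribute, Value>; llvm::dyn_cast<T>(ofr)
// replaces the deprecated ofr.dyn_cast<T>(). Static tile sizes arrive as an
// IntegerAttr, dynamic ones as a Value.
static int64_t staticSizeOrDynamic(mlir::OpFoldResult ofr) {
  if (auto attr = llvm::dyn_cast<mlir::Attribute>(ofr))
    return llvm::cast<mlir::IntegerAttr>(attr).getInt();
  return mlir::ShapedType::kDynamic; // dynamic: the size is carried as a Value
}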