From 08947ac210a372f68b820a5e18339a602ab28655 Mon Sep 17 00:00:00 2001 From: liuruyan <44316842+liuruyan@users.noreply.github.com> Date: Wed, 20 Sep 2023 15:21:04 +0800 Subject: [PATCH] [PIR]Add constraint function for dynamic shape (#57459) * build shape. * add applyOpConstraint func --- .../pir/dialect/operator/ir/op_dialect.cc | 3 +- paddle/pir/dialect/shape/ir/shape_dialect.cc | 8 + paddle/pir/dialect/shape/ir/shape_dialect.h | 8 +- paddle/pir/dialect/shape/ir/shape_op.cc | 25 ++- paddle/pir/dialect/shape/ir/shape_op.h | 18 +- .../transforms/shape_optimization_pass.cc | 36 +++- paddle/pir/dialect/shape/utils/shape_utils.cc | 173 +++++++++++++++++- paddle/pir/dialect/shape/utils/shape_utils.h | 34 ++++ .../pir/shape_dialect/constraint_pass_test.cc | 62 ++++++- .../cpp/pir/shape_dialect/symbolic_op_test.cc | 12 +- 10 files changed, 337 insertions(+), 42 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 2c85ea18d3da3..6e7e49e3bcee9 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/dialect/operator/ir/type_storage.h" #include "paddle/fluid/pir/dialect/operator/transforms/param_to_variable.h" +#include "paddle/pir/core/builtin_type_interfaces.h" #include "paddle/pir/core/ir_printer.h" #include "paddle/pir/core/utils.h" @@ -66,7 +67,7 @@ void OperatorDialect::PrintType(pir::Type type, std::ostream &os) const { if (auto tensor_type = type.dyn_cast()) { os << "tensor<"; for (auto d : phi::vectorize(tensor_type.dims())) { - os << d; + pir::ShapedTypeInterface::isDynamic(d) ? os << "?" : os << d; os << "x"; } tensor_type.dtype().Print(os); diff --git a/paddle/pir/dialect/shape/ir/shape_dialect.cc b/paddle/pir/dialect/shape/ir/shape_dialect.cc index 611d2d95c4810..4367670156efc 100644 --- a/paddle/pir/dialect/shape/ir/shape_dialect.cc +++ b/paddle/pir/dialect/shape/ir/shape_dialect.cc @@ -31,6 +31,14 @@ void ShapeDialect::initialize() { TensorDimOp>(); } +void ShapeDialect::PrintOperation(Operation *op, IrPrinter &printer) const { + if (auto func_op = op->dyn_cast()) { + func_op.Print(printer); + } else { + printer.PrintGeneralOperation(op); + } +} + } // namespace dialect } // namespace pir diff --git a/paddle/pir/dialect/shape/ir/shape_dialect.h b/paddle/pir/dialect/shape/ir/shape_dialect.h index 16d5d2ea68e07..b4ae3aa617210 100644 --- a/paddle/pir/dialect/shape/ir/shape_dialect.h +++ b/paddle/pir/dialect/shape/ir/shape_dialect.h @@ -21,16 +21,18 @@ namespace dialect { /// /// \brief Shape Dialect: /// -class IR_API ShapeDialect : public pir::Dialect { +class IR_API ShapeDialect : public Dialect { public: - explicit ShapeDialect(pir::IrContext *context); + explicit ShapeDialect(IrContext* context); /// /// \brief Each Dialect needs to provide a name function to return the name of /// the Dialect. /// /// \return The name of this Dialect. 
/// - static const char *name() { return "shape"; } + static const char* name() { return "shape"; } + void PrintOperation(Operation* op, + IrPrinter& printer) const override; // NOLINT private: void initialize(); diff --git a/paddle/pir/dialect/shape/ir/shape_op.cc b/paddle/pir/dialect/shape/ir/shape_op.cc index 530a46d0328eb..7220adb8f1dbf 100644 --- a/paddle/pir/dialect/shape/ir/shape_op.cc +++ b/paddle/pir/dialect/shape/ir/shape_op.cc @@ -177,7 +177,7 @@ void TieProductEqualOp::Build(Builder &builder, argument.AddInputs(rhs); } -std::vector TieProductEqualOp::getLhs() { +std::vector TieProductEqualOp::lhs() { int64_t lhs_len = attribute("lhs_len").data(); std::vector res; for (uint32_t idx = 0; idx < lhs_len; idx++) { @@ -185,7 +185,7 @@ std::vector TieProductEqualOp::getLhs() { } return res; } -std::vector TieProductEqualOp::getRhs() { +std::vector TieProductEqualOp::rhs() { int64_t lhs_len = attribute("lhs_len").data(); int64_t rhs_len = attribute("rhs_len").data(); std::vector res; @@ -211,9 +211,9 @@ void TieShapeOp::Build(Builder &builder, // NOLINT argument.AddInputs(dims); } -Value TieShapeOp::getValue() { return operand_source(0); } +Value TieShapeOp::value() { return operand_source(0); } -std::vector TieShapeOp::getShapeDimIndexes() { +std::vector TieShapeOp::dims() { std::vector res; for (uint32_t i = 1; i < num_operands(); i++) { res.push_back(operand_source(i)); @@ -231,6 +231,17 @@ Block *FuncOp::block() { return region.front(); } +void FuncOp::Print(IrPrinter &printer) { + auto &os = printer.os; + os << " shape.func () "; + os << "{"; + for (auto item : *block()) { + os << "\n "; + printer.PrintOperation(item); + } + os << "\n }"; +} + void TensorDimOp::Build(Builder &builder, OperationArgument &argument, Value source, @@ -245,16 +256,16 @@ void TensorDimOp::Build(Builder &builder, int64_t index) { OpResult indexValue = builder - .Build(Int64Attribute::get(IrContext::Instance(), 2), + .Build(Int64Attribute::get(IrContext::Instance(), index), IndexType::get(IrContext::Instance())) ->result(0); argument.AddInputs({source, indexValue}); argument.output_types.emplace_back(IndexType::get(IrContext::Instance())); } -Value TensorDimOp::getSource() { return operand_source(0); } +Value TensorDimOp::source() { return operand_source(0); } -Value TensorDimOp::getIndex() { return operand_source(1); } +Value TensorDimOp::index() { return operand_source(1); } } // namespace dialect } // namespace pir diff --git a/paddle/pir/dialect/shape/ir/shape_op.h b/paddle/pir/dialect/shape/ir/shape_op.h index 67a2372bf4fca..3163d404a61ee 100644 --- a/paddle/pir/dialect/shape/ir/shape_op.h +++ b/paddle/pir/dialect/shape/ir/shape_op.h @@ -16,6 +16,7 @@ #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_type_interfaces.h" +#include "paddle/pir/core/ir_printer.h" #include "paddle/pir/core/op_base.h" namespace pir { @@ -75,7 +76,7 @@ class IR_API DimOp : public Op { const std::string getName(); void setName(std::string attrValue); - pir::OpResult out() { return result(0); } + OpResult out() { return result(0); } void Verify() {} }; @@ -96,8 +97,8 @@ class IR_API TieProductEqualOp : public Op { OperationArgument &argument, // NOLINT const std::vector &lhs, const std::vector &rhs); - std::vector getLhs(); - std::vector getRhs(); + std::vector lhs(); + std::vector rhs(); void Verify() {} }; @@ -117,8 +118,8 @@ class IR_API TieShapeOp : public Op { OperationArgument &argument, // NOLINT Value input, const std::vector &dims); - Value getValue(); - std::vector getShapeDimIndexes(); + Value 
value(); + std::vector dims(); void Verify() {} }; @@ -132,7 +133,8 @@ class IR_API FuncOp : public Op { static void Build(Builder &builder, // NOLINT OperationArgument &argument); // NOLINT - pir::Block *block(); + void Print(IrPrinter &printer); // NOLINT + Block *block(); void Verify() {} }; @@ -152,8 +154,8 @@ class IR_API TensorDimOp : public Op { OperationArgument &argument, // NOLINT Value source, int64_t index); - Value getIndex(); - Value getSource(); + Value index(); + Value source(); OpResult out() { return result(0); } void Verify() {} }; diff --git a/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc index 3fcc784de8ab3..6bbb918ebc1f1 100644 --- a/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc +++ b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc @@ -18,10 +18,14 @@ #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/program.h" +#include "paddle/pir/dialect/shape/utils/shape_utils.h" #include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" #include "paddle/pir/pass/pass_registry.h" namespace { +using PassPipelineRunner = + std::function; bool InsertTieShapeOnValue(pir::Value value, pir::Builder& builder) { // NOLINT @@ -41,8 +45,9 @@ bool InsertTieShapeOnRegion(pir::Region* region); bool InsertTieShapeOnOperation(pir::Operation* op, pir::Builder& builder) { // NOLINT - if (op->isa()) return true; - // TODO(liujinnan): skip the specialized Ops. + // TODO(zhangbo63): skip more specialized Ops. + if (op->isa() || op->isa()) + return true; for (size_t i = 0; i < op->num_regions(); ++i) { if (!InsertTieShapeOnRegion(&(op->region(i)))) return false; @@ -55,7 +60,7 @@ bool InsertTieShapeOnOperation(pir::Operation* op, return true; } -bool insertTieShapeOnBlock(pir::Block* block) { +bool InsertTieShapeOnBlock(pir::Block* block) { pir::Builder builder = pir::Builder(pir::IrContext::Instance(), block, block->begin()); // TODO(liujinnan): mapping block arguments @@ -70,7 +75,7 @@ bool insertTieShapeOnBlock(pir::Block* block) { bool InsertTieShapeOnRegion(pir::Region* region) { for (pir::Block* block : *region) { - if (!insertTieShapeOnBlock(block)) return false; + if (!InsertTieShapeOnBlock(block)) return false; } return true; } @@ -81,6 +86,20 @@ bool MaterializeShapeComputation(pir::ModuleOp m) { return true; } +bool OptimizeShapeComputation(pir::ModuleOp m, PassPipelineRunner runner) { + // TODO(liujinnan): Do some Canonicalizer. + pir::SymbolicDimMgr mgr(m); + IR_ENFORCE(mgr.Load(), + "SymbolicDimMgr Load failed in OptimizeShapeComputation."); + pir::ShapeComputationIRAnalysis analysis(m, mgr); + if (!analysis.Run()) { + return false; + } + IR_ENFORCE(mgr.Save(), + "SymbolicDimMgr save failed in OptimizeShapeComputation."); + return true; +} + class ShapeOptimizationPass : public pir::Pass { public: ShapeOptimizationPass() : pir::Pass("shape_optimization", 0) {} @@ -89,10 +108,17 @@ class ShapeOptimizationPass : public pir::Pass { auto module_op = op->dyn_cast(); IR_ENFORCE(module_op, "ShapeOptimizationPass should run on module op."); MaterializeShapeComputation(module_op); + // runner is for Canonicalizer. 
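+    // The runner just forwards a PassManager over the module's program, so
+    // OptimizeShapeComputation can drive an inner pipeline without caring
+    // how this pass itself was scheduled.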
+ PassPipelineRunner runner = [this](pir::PassManager& pm, pir::ModuleOp m) { + return pm.Run(m.program()); + }; + if (!OptimizeShapeComputation(module_op, runner)) { + return; + } } bool CanApplyOn(pir::Operation* op) const override { - return op->name() == "builtin.module" && op->num_regions() > 0; + return op->isa() && op->num_regions() > 0; } }; diff --git a/paddle/pir/dialect/shape/utils/shape_utils.cc b/paddle/pir/dialect/shape/utils/shape_utils.cc index ac545b2c3a7ee..7998808cc20b1 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_utils.cc @@ -101,8 +101,8 @@ bool SymbolicDimMgr::LoadShapeConstraintGraph() { for (auto op : constraint_vec) { SymbolicDimProduct lhs, rhs; - if (!build_sym_product(op.getLhs(), lhs) || - !build_sym_product(op.getRhs(), rhs) || + if (!build_sym_product(op.lhs(), lhs) || + !build_sym_product(op.rhs(), rhs) || !MapSymbolicDimProductEqual(lhs, rhs)) return false; } @@ -645,7 +645,7 @@ SymbolicDimShapeAnalysis::SymbolicDimShapeAnalysis(ModuleOp m) for (auto op : *(m_.block())) { auto tieShapeOp = op->dyn_cast(); if (!tieShapeOp) continue; - Value result = tieShapeOp.getValue(); + Value result = tieShapeOp.value(); auto& symbols = value2SymDims_[result]; auto attrs = tieShapeOp @@ -724,4 +724,171 @@ bool SymbolicDimShapeAnalysis::IsProductEqual(Value lhs, return mgr_.IsSymbolicDimProductEqual(lhsProd, rhsProd); } + +ShapeComputationIRAnalysis::ShapeComputationIRAnalysis(ModuleOp m, + SymbolicDimMgr& mgr) + : m_(m), mgr_(mgr) {} + +bool ShapeComputationIRAnalysis::Run() { + // Make sure only run once. + if (initialized_) return false; + initialized_ = true; + auto buildShapeFunc = + std::bind(&ShapeComputationIRAnalysis::BuildShapeOnOperation, + this, + std::placeholders::_1); + if (!RunOnRegion(&(m_->region(0)), buildShapeFunc)) return false; + auto applyOpConstraintFunc = + std::bind(&ShapeComputationIRAnalysis::ApplyOpConstraint, + this, + std::placeholders::_1); + if (!RunOnRegion(&(m_->region(0)), applyOpConstraintFunc)) return false; + return true; +} + +bool ShapeComputationIRAnalysis::RunOnRegion(Region* region, func fn) { + for (Block* block : *region) { + if (!RunOnBlock(block, fn)) return false; + } + return true; +} + +bool ShapeComputationIRAnalysis::RunOnBlock(Block* block, func fn) { + // TODO(liujinnan): mapping block arguments + + std::vector op_list; + for (Operation* op : *block) op_list.push_back(op); + for (Operation* op : op_list) { + if (!RunOnOperation(op, fn)) return false; + } + return true; +} + +bool ShapeComputationIRAnalysis::RunOnOperation(Operation* op, func fn) { + for (size_t i = 0; i < op->num_regions(); ++i) { + if (!RunOnRegion(&(op->region(i)), fn)) return false; + } + return fn(op); +} + +bool ShapeComputationIRAnalysis::BuildShapeOnOperation(Operation* op) { + if (op->isa()) return true; + if (op->isa()) { + Value value = op->operand_source(0); + std::vector symbols; + if (op->HasAttribute(SymbolicDim::getSymbolicDimAttrName())) { + auto attrs = + op->attribute(SymbolicDim::getSymbolicDimAttrName()) + .AsVector(); + for (Attribute attr : attrs) { + auto sym = mgr_.symbolTable().Lookup( + attr.dyn_cast().AsString()); + assert(sym); + SymbolicDim root = mgr_.GetRootSymbolicDim(sym); + symbols.push_back(root); + } + } else { + symbols = mgr_.CreateSymbolicDimsForRankedValue(value); + std::vector attrs; + for (SymbolicDim sym : symbols) { + Attribute rootSymbol = + StrAttribute::get(m_->ir_context(), sym.getSymName()); + attrs.push_back(rootSymbol); + } + 
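+        // Record the fresh symbols on the op itself, so the lookup branch
+        // above finds them on a rerun instead of minting new ones.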
op->set_attribute(SymbolicDim::getSymbolicDimAttrName(), + ArrayAttribute::get(m_->ir_context(), attrs)); + } + rankedTensor2SymDims_[value] = std::move(symbols); + return true; + } + for (size_t i = 0; i < op->num_results(); ++i) { + if (!BuildShapeOnValue(op->result(i))) return false; + } + return true; +} + +bool ShapeComputationIRAnalysis::BuildShapeOnValue(Value value) { + Type ty = value.type(); + if (IsIntOrIndex(ty)) { + SymbolicDim sym = mgr_.NewSymbolicDim(); + value2SymDim_[value] = sym; + } else if (IsCandidateShapeTensorType(ty)) { + auto shapedTy = ty.dyn_cast_interface(); + std::vector symbols; + for (size_t i = 0, d = shapedTy.getShape()[0]; i < d; ++i) + symbols.push_back(mgr_.NewSymbolicDim()); + shapeTensor2SymDims_[value] = std::move(symbols); + } + return true; +} + +bool ShapeComputationIRAnalysis::ApplyOpConstraint(Operation* op) { + IR_ENFORCE(ApplyIndexOpConstraint(op), + "Fail to apply constraint for index op"); + IR_ENFORCE(ApplyTieShapeOpConstraint(op), + "Fail to apply constraint for tie_shape op"); + + // TODO(zhangbo63): add more constraints + return true; +} + +bool ShapeComputationIRAnalysis::ApplyIndexOpConstraint(Operation* op) { + if (op->num_results() == 0) return true; + + Type ty = op->result(0).type(); + if (!IsIntOrIndex(ty)) return true; + + if (auto dimOp = op->dyn_cast()) { + int64_t dimIndex = dimOp.index() + .dyn_cast() + .owner() + ->attribute("value") + .data(); + value2SymDim_[dimOp.out()].updateKnownNonNegative(true); + if (!mgr_.MapSymbolicDimEqual( + value2SymDim_[dimOp.out()], + rankedTensor2SymDims_[dimOp.source()][dimIndex])) { + return false; + } + + } else if (auto constOp = op->dyn_cast()) { + int64_t val = constOp.value().dyn_cast().data(); + if (!mgr_.MapSymbolicDimEqual(value2SymDim_[op->result(0)], + mgr_.NewConstantSymbolicDim(val))) { + return false; + } + } + // TODO(zhangbo63): add support for reifyInferShape. (e.g. 
mul/add) + return true; +} + +bool ShapeComputationIRAnalysis::ApplyTieShapeOpConstraint(Operation* op) { + if (auto tieShape = op->dyn_cast()) { + auto& value = rankedTensor2SymDims_[op->operand_source(0)]; + for (size_t idx = 0; idx < tieShape.dims().size(); ++idx) { + if (!mgr_.MapSymbolicDimEqual(value2SymDim_[tieShape.dims()[idx]], + value[idx])) + return false; + mgr_.GetRootSymbolicDim(value[idx]).updateKnownNonNegative(true); + } + } + return true; +} + +bool IsIntOrIndex(Type type) { + return type.isa() || type.isa() || + type.isa() || type.isa() || + type.isa() || type.isa(); +} + +bool IsCandidateShapeTensorType(Type ty) { + if (auto tensorTy = ty.dyn_cast()) { + auto shapedTy = tensorTy.dyn_cast_interface(); + return (shapedTy.getRank() == 1 && shapedTy.hasStaticShape() && + IsIntOrIndex(shapedTy.getElementType()) && + shapedTy.getShape()[0] < 32); + } + return false; +} + } // namespace pir diff --git a/paddle/pir/dialect/shape/utils/shape_utils.h b/paddle/pir/dialect/shape/utils/shape_utils.h index ab728fc6a00a0..bb6dd58cebb26 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_utils.h @@ -187,4 +187,38 @@ class SymbolicDimShapeAnalysis : public ShapeAnalysis { SymbolicDimMgr mgr_; std::unordered_map> value2SymDims_; }; + +class ShapeComputationIRAnalysis { + public: + using func = std::function; + explicit ShapeComputationIRAnalysis(ModuleOp m, + SymbolicDimMgr& mgr); // NOLINT + bool Run(); + + private: + bool RunOnRegion(Region* region, func fn); + bool RunOnBlock(Block* block, func fn); + bool RunOnOperation(Operation* op, func fn); + + bool BuildShapeOnOperation(Operation* op); + bool BuildShapeOnValue(Value value); + + bool ApplyOpConstraint(Operation* op); + bool ApplyIndexOpConstraint(Operation* op); + bool ApplyTieShapeOpConstraint(Operation* op); + + bool initialized_ = false; + ModuleOp m_; + SymbolicDimMgr& mgr_; + + std::unordered_map value2SymDim_; + + // shape tensor is the 1D ranked tensor with int/index dtype. 
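+  // Each such tensor gets one SymbolicDim per element; see
+  // BuildShapeOnValue().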
+ std::unordered_map> shapeTensor2SymDims_; + + std::unordered_map> rankedTensor2SymDims_; +}; + +bool IsIntOrIndex(Type type); +bool IsCandidateShapeTensorType(Type ty); } // namespace pir diff --git a/test/cpp/pir/shape_dialect/constraint_pass_test.cc b/test/cpp/pir/shape_dialect/constraint_pass_test.cc index b002c348a60f2..7c645044a09d0 100644 --- a/test/cpp/pir/shape_dialect/constraint_pass_test.cc +++ b/test/cpp/pir/shape_dialect/constraint_pass_test.cc @@ -28,6 +28,7 @@ #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/builtin_type_interfaces.h" #include "paddle/pir/core/cast_utils.h" #include "paddle/pir/core/dialect.h" #include "paddle/pir/core/enforce.h" @@ -37,7 +38,9 @@ #include "paddle/pir/core/program.h" #include "paddle/pir/core/value.h" #include "paddle/pir/dialect/shape/ir/shape_dialect.h" +#include "paddle/pir/dialect/shape/ir/shape_op.h" #include "paddle/pir/dialect/shape/transforms/shape_optimization_pass.h" +#include "paddle/pir/dialect/shape/utils/shape_utils.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_manager.h" @@ -58,15 +61,16 @@ pir::Operation *CreateDenseTensorOp( pir::IrContext *ctx, const phi::DDim &dims, const std::vector &attribute_names, - const std::vector &attributes) { + const std::vector &attributes, + const pir::Type &dtype = + pir::Float32Type::get(pir::IrContext::Instance())) { std::vector op_inputs = {}; - pir::Type fp32_dtype = pir::Float32Type::get(ctx); phi::DataLayout data_layout = phi::DataLayout::NCHW; phi::LoD lod = {{0, 1, 2}}; size_t offset = 0; std::vector op_output_types = { paddle::dialect::DenseTensorType::get( - ctx, fp32_dtype, dims, data_layout, lod, offset)}; + ctx, dtype, dims, data_layout, lod, offset)}; pir::Operation *op = pir::Operation::Create(op_inputs, CreateAttributeMap(attribute_names, attributes), @@ -75,22 +79,62 @@ pir::Operation *CreateDenseTensorOp( return op; } -TEST(constraint_pass, materialize_shape) { +TEST(constraint_pass, materialize_and_build_shape) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); pir::PassManager pm(ctx); ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); - pir::Operation *op0 = - CreateDenseTensorOp(ctx, {-100000, 2}, {"op0_attr"}, {"op0_name"}); + pir::Operation *op0 = CreateDenseTensorOp( + ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op0_attr"}, {"op0_name"}); program.block()->push_back(op0); pir::Operation *op1 = - CreateDenseTensorOp(ctx, {-100000, 2, 2}, {"op1_attr"}, {"op1_name"}); + CreateDenseTensorOp(ctx, + {pir::ShapedTypeInterface::kDynamic, 2, 2}, + {"op1_attr"}, + {"op1_name"}); program.block()->push_back(op1); EXPECT_EQ(program.block()->size(), static_cast(2)); pm.AddPass(pir::CreateShapeOptimizationPass()); + + EXPECT_TRUE(pm.Run(&program)); + + // 5 ConstantOp + 5 TensorDim + 2 TieShape + op0 + op1 + 1 funcOp == 15 Ops. 
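+  // (TensorDimOp::Build emits one ConstantOp per index, and op0/op1 carry
+  // 2 and 3 dims respectively, hence 5 of each.)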
+ EXPECT_EQ(program.block()->size(), static_cast(15)); + + std::stringstream ss; + program.Print(ss); + + LOG(INFO) << ss.str(); +} + +TEST(constraint_pass, shape_computation_run) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + pir::PassManager pm(ctx); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ::pir::Builder builder = ::pir::Builder(ctx, program.block()); + builder.Build(); + pir::Operation *op0 = + CreateDenseTensorOp(ctx, + {2}, + {"op0_attr"}, + {"op0_name"}, + pir::Int64Type::get(pir::IrContext::Instance())); + program.block()->push_back(op0); + pir::Operation *op1 = CreateDenseTensorOp( + ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op1_attr"}, {"op1_name"}); + program.block()->push_back(op1); + + pm.AddPass(pir::CreateShapeOptimizationPass()); + EXPECT_TRUE(pm.Run(&program)); - // 5 ConstantOp + 5 TensorDim + 2 TieShape + op0 + op1 == 14 Ops. - EXPECT_EQ(program.block()->size(), static_cast(14)); + pir::SymbolicDimMgr mgr(program.module_op()); + EXPECT_TRUE(mgr.Load()); + pir::ShapeComputationIRAnalysis analysis(program.module_op(), mgr); + EXPECT_TRUE(analysis.Run()); + EXPECT_FALSE(analysis.Run()); + EXPECT_TRUE(mgr.Save()); } diff --git a/test/cpp/pir/shape_dialect/symbolic_op_test.cc b/test/cpp/pir/shape_dialect/symbolic_op_test.cc index 9f4f99857b807..b2b62c7b46aa9 100644 --- a/test/cpp/pir/shape_dialect/symbolic_op_test.cc +++ b/test/cpp/pir/shape_dialect/symbolic_op_test.cc @@ -429,8 +429,8 @@ TEST(shape_op, tie_product_equal) { 3, std::vector{dimOp0, dimOp1, dimOp2, dimOp3, dimOp4}); - std::vector lhs = tie_product_equal.getLhs(); - std::vector rhs = tie_product_equal.getRhs(); + std::vector lhs = tie_product_equal.lhs(); + std::vector rhs = tie_product_equal.rhs(); std::vector lhs_ref{dimOp0, dimOp1}; std::vector rhs_ref{dimOp2, dimOp3, dimOp4}; @@ -461,7 +461,7 @@ TEST(shape_op, tie_shape) { pir::dialect::TieShapeOp tieShapeOp = builder.Build(res); - pir::Value tieShapeOpValue = tieShapeOp.getValue(); + pir::Value tieShapeOpValue = tieShapeOp.value(); pir::Attribute attrS0 = pir::StrAttribute::get(ctx, "S0"); pir::Attribute attrS1 = pir::StrAttribute::get(ctx, "S1"); @@ -613,7 +613,7 @@ TEST(shape_op, tensor_dim) { EXPECT_EQ(res0.type(), pir::IndexType::get(ctx)); EXPECT_EQ(res1.type(), pir::IndexType::get(ctx)); - EXPECT_EQ(tensorDimOp0.getSource(), resDenseTensorValue); - EXPECT_EQ(tensorDimOp1.getSource(), resDenseTensorValue); - EXPECT_EQ(tensorDimOp1.getIndex(), indexValue); + EXPECT_EQ(tensorDimOp0.source(), resDenseTensorValue); + EXPECT_EQ(tensorDimOp1.source(), resDenseTensorValue); + EXPECT_EQ(tensorDimOp1.index(), indexValue); }
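
Note: MapSymbolicDimEqual / GetRootSymbolicDim / NewConstantSymbolicDim, as used in this patch, behave like a union-find whose roots may be bound to constant extents. The sketch below is a minimal self-contained illustration of that idea only; every name in it is a hypothetical stand-in, not the pir API.

// illustration_union_find.cc -- hypothetical stand-in, not the pir API.
#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

class SymDimUnionFind {
 public:
  // Creates a fresh symbolic dimension and returns its id.
  int64_t NewSym() {
    parent_.push_back(static_cast<int64_t>(parent_.size()));
    return parent_.back();
  }

  // Finds the representative of s, with path halving for near-O(1) lookups.
  int64_t Root(int64_t s) {
    while (parent_[s] != s) {
      parent_[s] = parent_[parent_[s]];
      s = parent_[s];
    }
    return s;
  }

  // Merges two equivalence classes; fails if they carry different constants.
  bool MapEqual(int64_t lhs, int64_t rhs) {
    int64_t l = Root(lhs);
    int64_t r = Root(rhs);
    if (l == r) return true;
    auto lc = constant_.find(l);
    auto rc = constant_.find(r);
    if (lc != constant_.end() && rc != constant_.end() &&
        lc->second != rc->second) {
      return false;  // e.g. dim == 2 on one side, dim == 3 on the other
    }
    parent_[l] = r;
    if (lc != constant_.end()) constant_[r] = lc->second;
    return true;
  }

  // Binds a class to a concrete extent (a tensor_dim over a static dim).
  bool BindConstant(int64_t s, int64_t value) {
    int64_t r = Root(s);
    auto it = constant_.find(r);
    if (it != constant_.end()) return it->second == value;
    constant_[r] = value;
    return true;
  }

 private:
  std::vector<int64_t> parent_;
  std::unordered_map<int64_t, int64_t> constant_;
};

int main() {
  SymDimUnionFind uf;
  int64_t s0 = uf.NewSym();
  int64_t s1 = uf.NewSym();
  int64_t s2 = uf.NewSym();
  assert(uf.MapEqual(s0, s1));      // a tie_shape-style equality
  assert(uf.BindConstant(s2, 2));   // a known static extent
  assert(uf.MapEqual(s1, s2));      // now s0 == s1 == s2 == 2
  assert(!uf.BindConstant(s0, 3));  // a contradictory constraint is rejected
  return 0;
}

Binding constants at the root is what would let ApplyIndexOpConstraint-style code reject a shape that the IR constrains to two different static extents.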
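
Note: ShapeComputationIRAnalysis::Run makes two passes over the module with different callbacks (BuildShapeOnOperation, then ApplyOpConstraint), visiting nested regions before the op that owns them and snapshotting each block because a callback may insert new ops. Below is a compilable sketch of that traversal shape, with made-up minimal IR types in place of pir's.

// illustration_walk.cc -- made-up minimal IR types, not the pir API.
#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Op;
using Block = std::vector<Op*>;
using Region = std::vector<Block>;

struct Op {
  std::string name;
  std::vector<Region> regions;
};

using OpFn = std::function<bool(Op*)>;

bool RunOnOperation(Op* op, const OpFn& fn);

bool RunOnBlock(Block& block, const OpFn& fn) {
  // Snapshot first: fn may insert ops into the live block while we iterate.
  Block snapshot = block;
  for (Op* op : snapshot) {
    if (!RunOnOperation(op, fn)) return false;
  }
  return true;
}

bool RunOnRegion(Region& region, const OpFn& fn) {
  for (Block& block : region) {
    if (!RunOnBlock(block, fn)) return false;
  }
  return true;
}

bool RunOnOperation(Op* op, const OpFn& fn) {
  for (Region& region : op->regions) {
    if (!RunOnRegion(region, fn)) return false;
  }
  return fn(op);  // nested regions are handled before the op itself
}

int main() {
  Op inner{"shape.tie_shape", {}};
  Op module{"builtin.module", {Region{Block{&inner}}}};
  // Two phases, the second skipped if the first reports failure.
  for (const char* phase : {"build_shape", "apply_constraint"}) {
    OpFn fn = [phase](Op* op) {
      std::cout << phase << ": " << op->name << "\n";
      return true;
    };
    if (!RunOnOperation(&module, fn)) return 1;
  }
  return 0;
}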