
[PIR] Add constraint function for dynamic shape (PaddlePaddle#57459)
* Build shape.

* Add applyOpConstraint func.
liuruyan authored and iosmers committed Sep 21, 2023
1 parent b9057a5 commit 08947ac
Showing 10 changed files with 337 additions and 42 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
@@ -20,6 +20,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/type_storage.h"
#include "paddle/fluid/pir/dialect/operator/transforms/param_to_variable.h"
#include "paddle/pir/core/builtin_type_interfaces.h"
#include "paddle/pir/core/ir_printer.h"
#include "paddle/pir/core/utils.h"

@@ -66,7 +67,7 @@ void OperatorDialect::PrintType(pir::Type type, std::ostream &os) const {
if (auto tensor_type = type.dyn_cast<DenseTensorType>()) {
os << "tensor<";
for (auto d : phi::vectorize(tensor_type.dims())) {
- os << d;
+ pir::ShapedTypeInterface::isDynamic(d) ? os << "?" : os << d;
os << "x";
}
tensor_type.dtype().Print(os);
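For context, a minimal stand-alone sketch (plain C++, not Paddle code) of what the print change above does. IsDynamic here is a stand-in for pir::ShapedTypeInterface::isDynamic, which is assumed to treat the usual negative sentinel extent as dynamic:

#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for pir::ShapedTypeInterface::isDynamic (assumption: a negative
// extent such as -1 marks a dynamic dimension).
bool IsDynamic(int64_t d) { return d < 0; }

int main() {
  std::vector<int64_t> dims{-1, 3, 224};
  std::cout << "tensor<";
  for (int64_t d : dims) {
    IsDynamic(d) ? std::cout << "?" : std::cout << d;
    std::cout << "x";
  }
  std::cout << "f32>\n";  // prints: tensor<?x3x224xf32>
  return 0;
}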
8 changes: 8 additions & 0 deletions paddle/pir/dialect/shape/ir/shape_dialect.cc
@@ -31,6 +31,14 @@ void ShapeDialect::initialize() {
TensorDimOp>();
}

+ void ShapeDialect::PrintOperation(Operation *op, IrPrinter &printer) const {
+ if (auto func_op = op->dyn_cast<FuncOp>()) {
+ func_op.Print(printer);
+ } else {
+ printer.PrintGeneralOperation(op);
+ }
+ }

} // namespace dialect
} // namespace pir

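The hook above gives FuncOp its own textual form (see FuncOp::Print later in this commit) and routes every other shape op to the generic printer. A stand-alone sketch of that dispatch pattern, with toy types rather than the pir API:

#include <iostream>

struct Printer { std::ostream& os = std::cout; };
struct Op { virtual ~Op() = default; };

// Toy stand-in for pir::dialect::FuncOp with a custom print form.
struct FuncLikeOp : Op {
  void Print(Printer& p) const { p.os << "shape.func () {\n }\n"; }
};

void PrintGeneralOperation(const Op&, Printer& p) { p.os << "<generic form>\n"; }

// Mirrors ShapeDialect::PrintOperation: custom form if the cast succeeds,
// generic fallback otherwise.
void PrintOperation(const Op& op, Printer& p) {
  if (auto* f = dynamic_cast<const FuncLikeOp*>(&op)) {
    f->Print(p);
  } else {
    PrintGeneralOperation(op, p);
  }
}

int main() {
  Printer p;
  FuncLikeOp func;
  Op other;
  PrintOperation(func, p);   // prints the custom shape.func form
  PrintOperation(other, p);  // prints the generic form
  return 0;
}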
8 changes: 5 additions & 3 deletions paddle/pir/dialect/shape/ir/shape_dialect.h
@@ -21,16 +21,18 @@ namespace dialect {
///
/// \brief Shape Dialect:
///
- class IR_API ShapeDialect : public pir::Dialect {
+ class IR_API ShapeDialect : public Dialect {
public:
- explicit ShapeDialect(pir::IrContext *context);
+ explicit ShapeDialect(IrContext* context);
///
/// \brief Each Dialect needs to provide a name function to return the name of
/// the Dialect.
///
/// \return The name of this Dialect.
///
- static const char *name() { return "shape"; }
+ static const char* name() { return "shape"; }
+ void PrintOperation(Operation* op,
+ IrPrinter& printer) const override; // NOLINT

private:
void initialize();
25 changes: 18 additions & 7 deletions paddle/pir/dialect/shape/ir/shape_op.cc
@@ -177,15 +177,15 @@ void TieProductEqualOp::Build(Builder &builder,
argument.AddInputs(rhs);
}

- std::vector<Value> TieProductEqualOp::getLhs() {
+ std::vector<Value> TieProductEqualOp::lhs() {
int64_t lhs_len = attribute<Int64Attribute>("lhs_len").data();
std::vector<Value> res;
for (uint32_t idx = 0; idx < lhs_len; idx++) {
res.push_back(operand_source(idx));
}
return res;
}
- std::vector<Value> TieProductEqualOp::getRhs() {
+ std::vector<Value> TieProductEqualOp::rhs() {
int64_t lhs_len = attribute<Int64Attribute>("lhs_len").data();
int64_t rhs_len = attribute<Int64Attribute>("rhs_len").data();
std::vector<Value> res;
@@ -211,9 +211,9 @@ void TieShapeOp::Build(Builder &builder,  // NOLINT
argument.AddInputs(dims);
}

- Value TieShapeOp::getValue() { return operand_source(0); }
+ Value TieShapeOp::value() { return operand_source(0); }

- std::vector<Value> TieShapeOp::getShapeDimIndexes() {
+ std::vector<Value> TieShapeOp::dims() {
std::vector<Value> res;
for (uint32_t i = 1; i < num_operands(); i++) {
res.push_back(operand_source(i));
@@ -231,6 +231,17 @@ Block *FuncOp::block() {
return region.front();
}

+ void FuncOp::Print(IrPrinter &printer) {
+ auto &os = printer.os;
+ os << " shape.func () ";
+ os << "{";
+ for (auto item : *block()) {
+ os << "\n ";
+ printer.PrintOperation(item);
+ }
+ os << "\n }";
+ }

void TensorDimOp::Build(Builder &builder,
OperationArgument &argument,
Value source,
@@ -245,16 +256,16 @@ void TensorDimOp::Build(Builder &builder,
int64_t index) {
OpResult indexValue =
builder
- .Build<ConstantOp>(Int64Attribute::get(IrContext::Instance(), 2),
+ .Build<ConstantOp>(Int64Attribute::get(IrContext::Instance(), index),
IndexType::get(IrContext::Instance()))
->result(0);
argument.AddInputs({source, indexValue});
argument.output_types.emplace_back(IndexType::get(IrContext::Instance()));
}

- Value TensorDimOp::getSource() { return operand_source(0); }
+ Value TensorDimOp::source() { return operand_source(0); }

- Value TensorDimOp::getIndex() { return operand_source(1); }
+ Value TensorDimOp::index() { return operand_source(1); }
} // namespace dialect
} // namespace pir

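One substantive fix above, easy to miss among the renames: the index-based TensorDimOp::Build overload previously hard-coded its ConstantOp to 2 instead of forwarding index, so every op built through it referred to dimension 2. A stand-alone sketch of that bug class (plain C++, not Paddle code):

#include <cassert>
#include <cstdint>

struct DimRef { int64_t index; };

// Old behavior: the convenience overload ignored its parameter.
DimRef MakeDimRefBuggy(int64_t /*index*/) { return DimRef{2}; }
// Fixed behavior: the parameter is forwarded into the constant.
DimRef MakeDimRefFixed(int64_t index) { return DimRef{index}; }

int main() {
  assert(MakeDimRefBuggy(0).index == 2);  // silently referred to dim 2
  assert(MakeDimRefFixed(0).index == 0);  // matches the requested index
  return 0;
}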
18 changes: 10 additions & 8 deletions paddle/pir/dialect/shape/ir/shape_op.h
@@ -16,6 +16,7 @@

#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/builtin_type_interfaces.h"
#include "paddle/pir/core/ir_printer.h"
#include "paddle/pir/core/op_base.h"

namespace pir {
@@ -75,7 +76,7 @@ class IR_API DimOp : public Op<DimOp> {

const std::string getName();
void setName(std::string attrValue);
- pir::OpResult out() { return result(0); }
+ OpResult out() { return result(0); }
void Verify() {}
};

@@ -96,8 +97,8 @@ class IR_API TieProductEqualOp : public Op<TieProductEqualOp> {
OperationArgument &argument, // NOLINT
const std::vector<Value> &lhs,
const std::vector<Value> &rhs);
- std::vector<pir::Value> getLhs();
- std::vector<pir::Value> getRhs();
+ std::vector<Value> lhs();
+ std::vector<Value> rhs();
void Verify() {}
};

@@ -117,8 +118,8 @@ class IR_API TieShapeOp : public Op<TieShapeOp> {
OperationArgument &argument, // NOLINT
Value input,
const std::vector<Value> &dims);
- Value getValue();
- std::vector<Value> getShapeDimIndexes();
+ Value value();
+ std::vector<Value> dims();
void Verify() {}
};

@@ -132,7 +133,8 @@ class IR_API FuncOp : public Op<FuncOp> {

static void Build(Builder &builder, // NOLINT
OperationArgument &argument); // NOLINT
- pir::Block *block();
+ void Print(IrPrinter &printer); // NOLINT
+ Block *block();
void Verify() {}
};

@@ -152,8 +154,8 @@ class IR_API TensorDimOp : public Op<TensorDimOp> {
OperationArgument &argument, // NOLINT
Value source,
int64_t index);
- Value getIndex();
- Value getSource();
+ Value index();
+ Value source();
OpResult out() { return result(0); }
void Verify() {}
};
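A hedged call-site sketch of the accessor renames above; how the op handles are obtained is elided, and only the signatures come from this header:

#include <vector>

#include "paddle/pir/dialect/shape/ir/shape_op.h"

// Hypothetical helper, for illustration only.
void InspectShapeOps(pir::dialect::TieShapeOp tie_shape,
                     pir::dialect::TensorDimOp tensor_dim) {
  pir::Value v = tie_shape.value();               // was getValue()
  std::vector<pir::Value> ds = tie_shape.dims();  // was getShapeDimIndexes()
  pir::Value src = tensor_dim.source();           // was getSource()
  pir::Value idx = tensor_dim.index();            // was getIndex()
  (void)v; (void)ds; (void)src; (void)idx;
}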
36 changes: 31 additions & 5 deletions paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc
@@ -18,10 +18,14 @@

#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/dialect/shape/utils/shape_utils.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_manager.h"
#include "paddle/pir/pass/pass_registry.h"

namespace {
+ using PassPipelineRunner =
+ std::function<bool(pir::PassManager&, pir::ModuleOp)>;

bool InsertTieShapeOnValue(pir::Value value,
pir::Builder& builder) { // NOLINT
@@ -41,8 +45,9 @@ bool InsertTieShapeOnRegion(pir::Region* region);

bool InsertTieShapeOnOperation(pir::Operation* op,
pir::Builder& builder) { // NOLINT
- if (op->isa<pir::dialect::TieShapeOp>()) return true;
- // TODO(liujinnan): skip the specialized Ops.
+ // TODO(zhangbo63): skip more specialized Ops.
+ if (op->isa<pir::dialect::TieShapeOp>() || op->isa<pir::dialect::FuncOp>())
+ return true;

for (size_t i = 0; i < op->num_regions(); ++i) {
if (!InsertTieShapeOnRegion(&(op->region(i)))) return false;
@@ -55,7 +60,7 @@ bool InsertTieShapeOnOperation(pir::Operation* op,
return true;
}

- bool insertTieShapeOnBlock(pir::Block* block) {
+ bool InsertTieShapeOnBlock(pir::Block* block) {
pir::Builder builder =
pir::Builder(pir::IrContext::Instance(), block, block->begin());
// TODO(liujinnan): mapping block arguments
@@ -70,7 +75,7 @@ bool insertTieShapeOnBlock(pir::Block* block) {

bool InsertTieShapeOnRegion(pir::Region* region) {
for (pir::Block* block : *region) {
- if (!insertTieShapeOnBlock(block)) return false;
+ if (!InsertTieShapeOnBlock(block)) return false;
}
return true;
}
@@ -81,6 +86,20 @@ bool MaterializeShapeComputation(pir::ModuleOp m) {
return true;
}

+ bool OptimizeShapeComputation(pir::ModuleOp m, PassPipelineRunner runner) {
+ // TODO(liujinnan): Do some Canonicalizer.
+ pir::SymbolicDimMgr mgr(m);
+ IR_ENFORCE(mgr.Load(),
+ "SymbolicDimMgr Load failed in OptimizeShapeComputation.");
+ pir::ShapeComputationIRAnalysis analysis(m, mgr);
+ if (!analysis.Run()) {
+ return false;
+ }
+ IR_ENFORCE(mgr.Save(),
+ "SymbolicDimMgr save failed in OptimizeShapeComputation.");
+ return true;
+ }

class ShapeOptimizationPass : public pir::Pass {
public:
ShapeOptimizationPass() : pir::Pass("shape_optimization", 0) {}
@@ -89,10 +108,17 @@ class ShapeOptimizationPass : public pir::Pass {
auto module_op = op->dyn_cast<pir::ModuleOp>();
IR_ENFORCE(module_op, "ShapeOptimizationPass should run on module op.");
MaterializeShapeComputation(module_op);
+ // runner is for Canonicalizer.
+ PassPipelineRunner runner = [this](pir::PassManager& pm, pir::ModuleOp m) {
+ return pm.Run(m.program());
+ };
+ if (!OptimizeShapeComputation(module_op, runner)) {
+ return;
+ }
}

bool CanApplyOn(pir::Operation* op) const override {
- return op->name() == "builtin.module" && op->num_regions() > 0;
+ return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
}
};

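The pass now has two phases: MaterializeShapeComputation inserts tie_shape ops, then OptimizeShapeComputation loads the SymbolicDimMgr, runs ShapeComputationIRAnalysis, and saves the manager back. The PassPipelineRunner lambda is built for a canonicalizer pipeline that the TODO leaves unimplemented. A stand-alone sketch of that runner pattern (toy types, not the pir API):

#include <functional>

struct PassManager { bool Run() { return true; } };
struct ModuleOp {};

using PassPipelineRunner = std::function<bool(PassManager&, ModuleOp)>;

// Mirrors the shape of OptimizeShapeComputation above: the helper receives a
// callable, so the caller decides how (and whether) a pipeline executes.
bool RunWithPipeline(ModuleOp m, PassPipelineRunner runner) {
  PassManager pm;  // a canonicalizer pipeline would be assembled here
  return runner(pm, m);
}

int main() {
  PassPipelineRunner runner = [](PassManager& pm, ModuleOp) {
    return pm.Run();
  };
  return RunWithPipeline(ModuleOp{}, runner) ? 0 : 1;
}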
