Skip to content

Commit

Permalink
【CINN】Integate cast_simplify into ir_simplify (#56958)
Browse files Browse the repository at this point in the history
* integate cast_simplify into ir_simplify

* fix cast simplify testcase
  • Loading branch information
Courtesy-Xs authored Sep 8, 2023
1 parent fc71459 commit ad0e7a2
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 162 deletions.
1 change: 0 additions & 1 deletion paddle/cinn/common/cas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/utils/string.h"

namespace cinn {
Expand Down
3 changes: 1 addition & 2 deletions paddle/cinn/common/ir_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/cast_simplify.h"

namespace cinn {
namespace common {
Expand Down Expand Up @@ -147,7 +146,7 @@ Expr IndiceToAbsOffset(const std::vector<Expr> &shape,
for (int i = 0; i < shape.size(); i++) {
CHECK_EQ(shape[i].type(), Int(32));
Expr indice_prod = indices[i];
optim::CastSimplify(&indice_prod);
optim::SimplifyCast(&indice_prod);
for (int j = i + 1; j < shape.size(); j++) {
indice_prod = RampRelatedMul(indice_prod, shape[j]);
}
Expand Down
1 change: 0 additions & 1 deletion paddle/cinn/optim/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ gather_srcs(
compute_inline_expand.cc
buffer_assign.cc
replace_const_param_to_integer.cc
cast_simplify.cc
lower_intrin.cc
cast_bool_to_int8.cc
collect_undefined_vars.cc
Expand Down
117 changes: 0 additions & 117 deletions paddle/cinn/optim/cast_simplify.cc

This file was deleted.

31 changes: 0 additions & 31 deletions paddle/cinn/optim/cast_simplify.h

This file was deleted.

12 changes: 5 additions & 7 deletions paddle/cinn/optim/cast_simplify_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/optim/cast_simplify.h"

#include <gtest/gtest.h>

#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"

#include "paddle/cinn/optim/ir_simplify.h"
namespace cinn::optim {

TEST(CastSimplify, same_type) {
Var n("n");
Expr a = ir::Cast::Make(Int(32), n);
LOG(INFO) << n->type();
LOG(INFO) << a;
CastSimplify(&a);
SimplifyCast(&a);
ASSERT_EQ(utils::GetStreamCnt(a), "n");
}

TEST(CastSimplify, Imm_int) {
Expr a = ir::Cast::Make(Int(64), Expr(1));
Expr c = ir::Cast::Make(Int(32), a);
LOG(INFO) << c;
CastSimplify(&c);
SimplifyCast(&c);
LOG(INFO) << c;
ASSERT_EQ(utils::GetStreamCnt(c), "1");
ASSERT_EQ(c.type(), Int(32));
Expand All @@ -44,7 +42,7 @@ TEST(CastSimplify, Imm_double) {
Expr a = ir::Cast::Make(Float(64), Expr(2.33));
Expr c = ir::Cast::Make(Int(32), a);
LOG(INFO) << c;
CastSimplify(&c);
SimplifyCast(&c);
LOG(INFO) << c;
ASSERT_EQ(utils::GetStreamCnt(c), "2");
ASSERT_EQ(c.type(), Int(32));
Expand All @@ -54,7 +52,7 @@ TEST(CastSimplify, Imm_uint) {
Expr a = ir::Cast::Make(UInt(64), Expr(1));
Expr c = ir::Cast::Make(UInt(32), a);
LOG(INFO) << c;
CastSimplify(&c);
SimplifyCast(&c);
LOG(INFO) << c;
ASSERT_EQ(utils::GetStreamCnt(c), "1");
ASSERT_EQ(c.type(), UInt(32));
Expand Down
90 changes: 88 additions & 2 deletions paddle/cinn/optim/ir_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/utils/string.h"

namespace cinn {
namespace optim {
using namespace ir; // NOLINT
using common::bfloat16;
using common::ExprToGinacConverter;
using common::float16;
using utils::GetStreamCnt;
using utils::Replace;

Expand Down Expand Up @@ -346,11 +347,95 @@ struct SimplifyForLoopsMutator : public ir::IRMutator<> {
}
};

template <typename CastType, typename T>
CastType NormCastValue(T value) {
if (type_of<CastType>().is_uint() || type_of<T>().is_uint()) {
// not support uint
return static_cast<CastType>(value);
}

if (std::isinf(value)) {
return std::numeric_limits<CastType>::infinity();
} else if (std::isnan(value)) {
return std::numeric_limits<CastType>::signaling_NaN();
} else if (value >= static_cast<T>(std::numeric_limits<CastType>::max())) {
return std::numeric_limits<CastType>::max();
} else if (value <= static_cast<T>(std::numeric_limits<CastType>::lowest())) {
return std::numeric_limits<CastType>::lowest();
}
return static_cast<CastType>(value);
}

struct SimplifyCastMutator : public ir::IRMutator<> {
void operator()(Expr* expr) { ir::IRMutator<ir::Expr*>::Visit(expr, expr); }

void Visit(const ir::Cast* op, Expr* expr) {
auto* node = expr->As<ir::Cast>();

ir::IRMutator<ir::Expr*>::Visit(&node->v(), &node->v());

if (op->type() == op->v().type()) {
*expr = op->v();
return;
}

#define __CAST_TO_TYPE(type__) \
if (auto* i = op->v().As<ir::IntImm>()) { \
*expr = Expr(static_cast<type__>(i->value)); \
} else if (auto* f = op->v().As<ir::FloatImm>()) { \
*expr = Expr(static_cast<type__>(NormCastValue<type__>(f->value))); \
} else if (auto* u = op->v().As<ir::UIntImm>()) { \
*expr = Expr(static_cast<type__>(u->value)); \
} else { \
CINN_NOT_IMPLEMENTED \
}

if (op->v().is_constant()) {
if (op->type() == type_of<int8_t>()) {
__CAST_TO_TYPE(int8_t)
} else if (op->type() == type_of<int16_t>()) {
__CAST_TO_TYPE(int16_t)
} else if (op->type() == type_of<int32_t>()) {
__CAST_TO_TYPE(int32_t)
} else if (op->type() == type_of<int64_t>()) {
__CAST_TO_TYPE(int64_t)
} else if (op->type() == type_of<uint8_t>()) {
__CAST_TO_TYPE(uint8_t)
} else if (op->type() == type_of<uint16_t>()) {
__CAST_TO_TYPE(uint16_t)
} else if (op->type() == type_of<uint32_t>()) {
__CAST_TO_TYPE(uint32_t)
} else if (op->type() == type_of<uint64_t>()) {
__CAST_TO_TYPE(uint64_t)
} else if (op->type() == type_of<float>()) {
__CAST_TO_TYPE(float)
} else if (op->type() == type_of<double>()) {
__CAST_TO_TYPE(double)
} else if (op->type() == type_of<bool>()) {
__CAST_TO_TYPE(bool)
} else if (op->type() == type_of<uint32_t>()) {
__CAST_TO_TYPE(uint32_t)
} else if (op->type() == type_of<uint64_t>()) {
__CAST_TO_TYPE(uint64_t)
} else if (op->type() == type_of<bfloat16>()) {
// Cannot simplify!!! pass
__CAST_TO_TYPE(bfloat16)
} else if (op->type() == type_of<float16>()) {
// Cannot simplify!!! pass
__CAST_TO_TYPE(float16)
} else {
CINN_NOT_IMPLEMENTED
}
}
#undef __CAST_TO_TYPE
}
};

} // namespace

void Simplify(Expr* expr) {
VLOG(3) << "Begin Simplify " << *expr;
optim::CastSimplify(expr);
SimplifyCastMutator()(expr);
SimplifyRampMutator()(expr);
SimplifyLoadMutator()(expr);
SimplifyStoreMutator()(expr);
Expand All @@ -363,6 +448,7 @@ void Simplify(Expr* expr) {
ReplaceFracWithDivMutator()(expr);
}

void SimplifyCast(Expr* expr) { SimplifyCastMutator()(expr); }
void SimplifyForLoops(Expr* expr) { SimplifyForLoopsMutator()(expr); }
void SimplifyBlocks(Expr* expr) { SimplifyBlocksMutator()(expr); }

Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/optim/ir_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ namespace optim {
*/
void Simplify(Expr *expr);

void SimplifyCast(Expr *expr);

void SimplifyForLoops(Expr *expr);

void SimplifyBlocks(Expr *expr);
Expand Down
1 change: 0 additions & 1 deletion paddle/cinn/optim/optimize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/call_arg_list_to_pod_value.h"
#include "paddle/cinn/optim/cast_bool_to_int8.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/optim/eliminate_broadcast_in_forloop.h"
#include "paddle/cinn/optim/extern_call_process.h"
#include "paddle/cinn/optim/fold_cinn_call_arguments.h"
Expand Down

0 comments on commit ad0e7a2

Please sign in to comment.