diff --git a/paddle/cinn/common/cas.cc b/paddle/cinn/common/cas.cc index b72650301bbfe..6264c5b12d453 100644 --- a/paddle/cinn/common/cas.cc +++ b/paddle/cinn/common/cas.cc @@ -27,7 +27,6 @@ #include "paddle/cinn/ir/utils/ir_nodes_collector.h" #include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/ir/utils/ir_visitor.h" -#include "paddle/cinn/optim/cast_simplify.h" #include "paddle/cinn/utils/string.h" namespace cinn { diff --git a/paddle/cinn/common/ir_util.cc b/paddle/cinn/common/ir_util.cc index 97d56f2ceaeb4..f0f219ee105f7 100644 --- a/paddle/cinn/common/ir_util.cc +++ b/paddle/cinn/common/ir_util.cc @@ -21,7 +21,6 @@ #include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/utils/ir_mutator.h" #include "paddle/cinn/ir/utils/ir_printer.h" -#include "paddle/cinn/optim/cast_simplify.h" namespace cinn { namespace common { @@ -147,7 +146,7 @@ Expr IndiceToAbsOffset(const std::vector &shape, for (int i = 0; i < shape.size(); i++) { CHECK_EQ(shape[i].type(), Int(32)); Expr indice_prod = indices[i]; - optim::CastSimplify(&indice_prod); + optim::SimplifyCast(&indice_prod); for (int j = i + 1; j < shape.size(); j++) { indice_prod = RampRelatedMul(indice_prod, shape[j]); } diff --git a/paddle/cinn/optim/CMakeLists.txt b/paddle/cinn/optim/CMakeLists.txt index 45c38c2632717..99ae9cf3bd3d6 100755 --- a/paddle/cinn/optim/CMakeLists.txt +++ b/paddle/cinn/optim/CMakeLists.txt @@ -23,7 +23,6 @@ gather_srcs( compute_inline_expand.cc buffer_assign.cc replace_const_param_to_integer.cc - cast_simplify.cc lower_intrin.cc cast_bool_to_int8.cc collect_undefined_vars.cc diff --git a/paddle/cinn/optim/cast_simplify.cc b/paddle/cinn/optim/cast_simplify.cc deleted file mode 100644 index a6431e42e467c..0000000000000 --- a/paddle/cinn/optim/cast_simplify.cc +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/cinn/optim/cast_simplify.h" - -#include "paddle/cinn/ir/utils/ir_mutator.h" - -namespace cinn::optim { - -using cinn::common::bfloat16; -using cinn::common::float16; - -namespace { - -template -CastType NormCastValue(T value) { - if (type_of().is_uint() || type_of().is_uint()) { - // not support uint - return static_cast(value); - } - - if (std::isinf(value)) { - return std::numeric_limits::infinity(); - } else if (std::isnan(value)) { - return std::numeric_limits::signaling_NaN(); - } else if (value >= static_cast(std::numeric_limits::max())) { - return std::numeric_limits::max(); - } else if (value <= static_cast(std::numeric_limits::lowest())) { - return std::numeric_limits::lowest(); - } - return static_cast(value); -} - -struct Mutator : ir::IRMutator<> { - using ir::IRMutator<>::Visit; - - void Visit(const ir::Cast* op, Expr* expr) { - auto* node = expr->As(); - - Visit(&node->v(), &node->v()); - - if (op->type() == op->v().type()) { - *expr = op->v(); - return; - } - -#define __CAST_TO_TYPE(type__) \ - if (auto* i = op->v().As()) { \ - *expr = Expr(static_cast(i->value)); \ - } else if (auto* f = op->v().As()) { \ - *expr = Expr(static_cast(NormCastValue(f->value))); \ - } else if (auto* u = op->v().As()) { \ - *expr = Expr(static_cast(u->value)); \ - } else { \ - CINN_NOT_IMPLEMENTED \ - } - - if (op->v().is_constant()) { - if (op->type() == type_of()) { - __CAST_TO_TYPE(int8_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(int16_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(int32_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(int64_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(uint8_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(uint16_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(uint32_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(uint64_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(float) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(double) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(bool) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(uint32_t) - } else if (op->type() == type_of()) { - __CAST_TO_TYPE(uint64_t) - } else if (op->type() == type_of()) { - // Cannot simplify!!! pass - __CAST_TO_TYPE(bfloat16) - } else if (op->type() == type_of()) { - // Cannot simplify!!! pass - __CAST_TO_TYPE(float16) - } else { - CINN_NOT_IMPLEMENTED - } - } -#undef __CAST_TO_TYPE - } -}; - -} // namespace - -void CastSimplify(Expr* e) { - Mutator mutator; - mutator.Visit(e, e); -} - -} // namespace cinn::optim diff --git a/paddle/cinn/optim/cast_simplify.h b/paddle/cinn/optim/cast_simplify.h deleted file mode 100644 index 072f39783d187..0000000000000 --- a/paddle/cinn/optim/cast_simplify.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/cinn/ir/ir.h" - -namespace cinn::optim { - -/** - * Simplify the Cast nodes. - * - * There are several patterns: - * 1. the source and target type are the same, drop the Cast node - * 2. for intermediate numbers, just replace the Cast node with a Node of the - * target type - */ -void CastSimplify(Expr* e); - -} // namespace cinn::optim diff --git a/paddle/cinn/optim/cast_simplify_test.cc b/paddle/cinn/optim/cast_simplify_test.cc index 5b7a21bae86f4..57103db27f397 100644 --- a/paddle/cinn/optim/cast_simplify_test.cc +++ b/paddle/cinn/optim/cast_simplify_test.cc @@ -12,13 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/cinn/optim/cast_simplify.h" - #include #include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/utils/ir_printer.h" - +#include "paddle/cinn/optim/ir_simplify.h" namespace cinn::optim { TEST(CastSimplify, same_type) { @@ -26,7 +24,7 @@ TEST(CastSimplify, same_type) { Expr a = ir::Cast::Make(Int(32), n); LOG(INFO) << n->type(); LOG(INFO) << a; - CastSimplify(&a); + SimplifyCast(&a); ASSERT_EQ(utils::GetStreamCnt(a), "n"); } @@ -34,7 +32,7 @@ TEST(CastSimplify, Imm_int) { Expr a = ir::Cast::Make(Int(64), Expr(1)); Expr c = ir::Cast::Make(Int(32), a); LOG(INFO) << c; - CastSimplify(&c); + SimplifyCast(&c); LOG(INFO) << c; ASSERT_EQ(utils::GetStreamCnt(c), "1"); ASSERT_EQ(c.type(), Int(32)); @@ -44,7 +42,7 @@ TEST(CastSimplify, Imm_double) { Expr a = ir::Cast::Make(Float(64), Expr(2.33)); Expr c = ir::Cast::Make(Int(32), a); LOG(INFO) << c; - CastSimplify(&c); + SimplifyCast(&c); LOG(INFO) << c; ASSERT_EQ(utils::GetStreamCnt(c), "2"); ASSERT_EQ(c.type(), Int(32)); @@ -54,7 +52,7 @@ TEST(CastSimplify, Imm_uint) { Expr a = ir::Cast::Make(UInt(64), Expr(1)); Expr c = ir::Cast::Make(UInt(32), a); LOG(INFO) << c; - CastSimplify(&c); + SimplifyCast(&c); LOG(INFO) << c; ASSERT_EQ(utils::GetStreamCnt(c), "1"); ASSERT_EQ(c.type(), UInt(32)); diff --git a/paddle/cinn/optim/ir_simplify.cc b/paddle/cinn/optim/ir_simplify.cc index 3c0f9298dae2c..bfed498da521d 100644 --- a/paddle/cinn/optim/ir_simplify.cc +++ b/paddle/cinn/optim/ir_simplify.cc @@ -29,13 +29,14 @@ #include "paddle/cinn/ir/utils/ir_mutator.h" #include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/ir/utils/ir_visitor.h" -#include "paddle/cinn/optim/cast_simplify.h" #include "paddle/cinn/utils/string.h" namespace cinn { namespace optim { using namespace ir; // NOLINT +using common::bfloat16; using common::ExprToGinacConverter; +using common::float16; using utils::GetStreamCnt; using utils::Replace; @@ -346,11 +347,95 @@ struct SimplifyForLoopsMutator : public ir::IRMutator<> { } }; +template +CastType NormCastValue(T value) { + if (type_of().is_uint() || type_of().is_uint()) { + // not support uint + return static_cast(value); + } + + if (std::isinf(value)) { + return std::numeric_limits::infinity(); + } else if (std::isnan(value)) { + return std::numeric_limits::signaling_NaN(); + } else if (value >= static_cast(std::numeric_limits::max())) { + return std::numeric_limits::max(); + } else if (value <= static_cast(std::numeric_limits::lowest())) { + return std::numeric_limits::lowest(); + } + return static_cast(value); +} + +struct SimplifyCastMutator : public ir::IRMutator<> { + void operator()(Expr* expr) { ir::IRMutator::Visit(expr, expr); } + + void Visit(const ir::Cast* op, Expr* expr) { + auto* node = expr->As(); + + ir::IRMutator::Visit(&node->v(), &node->v()); + + if (op->type() == op->v().type()) { + *expr = op->v(); + return; + } + +#define __CAST_TO_TYPE(type__) \ + if (auto* i = op->v().As()) { \ + *expr = Expr(static_cast(i->value)); \ + } else if (auto* f = op->v().As()) { \ + *expr = Expr(static_cast(NormCastValue(f->value))); \ + } else if (auto* u = op->v().As()) { \ + *expr = Expr(static_cast(u->value)); \ + } else { \ + CINN_NOT_IMPLEMENTED \ + } + + if (op->v().is_constant()) { + if (op->type() == type_of()) { + __CAST_TO_TYPE(int8_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(int16_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(int32_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(int64_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint8_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint16_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint32_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint64_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(float) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(double) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(bool) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint32_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint64_t) + } else if (op->type() == type_of()) { + // Cannot simplify!!! pass + __CAST_TO_TYPE(bfloat16) + } else if (op->type() == type_of()) { + // Cannot simplify!!! pass + __CAST_TO_TYPE(float16) + } else { + CINN_NOT_IMPLEMENTED + } + } +#undef __CAST_TO_TYPE + } +}; + } // namespace void Simplify(Expr* expr) { VLOG(3) << "Begin Simplify " << *expr; - optim::CastSimplify(expr); + SimplifyCastMutator()(expr); SimplifyRampMutator()(expr); SimplifyLoadMutator()(expr); SimplifyStoreMutator()(expr); @@ -363,6 +448,7 @@ void Simplify(Expr* expr) { ReplaceFracWithDivMutator()(expr); } +void SimplifyCast(Expr* expr) { SimplifyCastMutator()(expr); } void SimplifyForLoops(Expr* expr) { SimplifyForLoopsMutator()(expr); } void SimplifyBlocks(Expr* expr) { SimplifyBlocksMutator()(expr); } diff --git a/paddle/cinn/optim/ir_simplify.h b/paddle/cinn/optim/ir_simplify.h index 1d6abf1cd9a9f..b0e695af53537 100644 --- a/paddle/cinn/optim/ir_simplify.h +++ b/paddle/cinn/optim/ir_simplify.h @@ -30,6 +30,8 @@ namespace optim { */ void Simplify(Expr *expr); +void SimplifyCast(Expr *expr); + void SimplifyForLoops(Expr *expr); void SimplifyBlocks(Expr *expr); diff --git a/paddle/cinn/optim/optimize.cc b/paddle/cinn/optim/optimize.cc index 18db99df48e2b..b1e73e3c58a9b 100644 --- a/paddle/cinn/optim/optimize.cc +++ b/paddle/cinn/optim/optimize.cc @@ -19,7 +19,6 @@ #include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/optim/call_arg_list_to_pod_value.h" #include "paddle/cinn/optim/cast_bool_to_int8.h" -#include "paddle/cinn/optim/cast_simplify.h" #include "paddle/cinn/optim/eliminate_broadcast_in_forloop.h" #include "paddle/cinn/optim/extern_call_process.h" #include "paddle/cinn/optim/fold_cinn_call_arguments.h"