From 64cc53360f02628c470dabbca111417211347a40 Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Fri, 12 Apr 2024 16:18:02 +0800 Subject: [PATCH 1/2] [Feature] Support config.force_scalarize_matrix to avoid perf-regression in certain scenario --- taichi/analysis/offline_cache_util.cpp | 1 + taichi/codegen/codegen_utils.h | 2 +- taichi/program/compile_config.cpp | 1 + taichi/program/compile_config.h | 1 + taichi/python/export_lang.cpp | 2 + taichi/transforms/auto_diff.cpp | 51 ++++++++++++++++++----- taichi/transforms/compile_to_offloads.cpp | 13 +++++- taichi/transforms/simplify.cpp | 5 ++- 8 files changed, 60 insertions(+), 16 deletions(-) diff --git a/taichi/analysis/offline_cache_util.cpp b/taichi/analysis/offline_cache_util.cpp index 20bcb84df44ec..2ea4d0df979da 100644 --- a/taichi/analysis/offline_cache_util.cpp +++ b/taichi/analysis/offline_cache_util.cpp @@ -78,6 +78,7 @@ static std::vector get_offline_cache_key_of_compile_config( serializer(config.experimental_auto_mesh_local); serializer(config.auto_mesh_local_default_occupacy); serializer(config.real_matrix_scalarize); + serializer(config.force_scalarize_matrix); serializer(config.half2_vectorization); serializer.finalize(); diff --git a/taichi/codegen/codegen_utils.h b/taichi/codegen/codegen_utils.h index 16ebaf0fa3328..fa96d05de5969 100644 --- a/taichi/codegen/codegen_utils.h +++ b/taichi/codegen/codegen_utils.h @@ -5,7 +5,7 @@ namespace taichi::lang { inline bool codegen_vector_type(const CompileConfig &config) { - return !config.real_matrix_scalarize; + return !(config.real_matrix_scalarize || config.force_scalarize_matrix); } // Parses a C-style printf format string specifier into its constituent parts. diff --git a/taichi/program/compile_config.cpp b/taichi/program/compile_config.cpp index bb6c102041d7d..8c29da201184d 100644 --- a/taichi/program/compile_config.cpp +++ b/taichi/program/compile_config.cpp @@ -43,6 +43,7 @@ CompileConfig::CompileConfig() { make_block_local = true; detect_read_only = true; real_matrix_scalarize = true; + force_scalarize_matrix = true; half2_vectorization = false; make_cpu_multithreading_loop = true; diff --git a/taichi/program/compile_config.h b/taichi/program/compile_config.h index 9290a54969c60..87832121f04d6 100644 --- a/taichi/program/compile_config.h +++ b/taichi/program/compile_config.h @@ -39,6 +39,7 @@ struct CompileConfig { bool make_block_local; bool detect_read_only; bool real_matrix_scalarize; + bool force_scalarize_matrix; bool half2_vectorization; bool make_cpu_multithreading_loop; DataType default_fp; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index cfdc8e987ae2d..867f4e785a841 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -218,6 +218,8 @@ void export_lang(py::module &m) { .def_readwrite("detect_read_only", &CompileConfig::detect_read_only) .def_readwrite("real_matrix_scalarize", &CompileConfig::real_matrix_scalarize) + .def_readwrite("force_scalarize_matrix", + &CompileConfig::force_scalarize_matrix) .def_readwrite("half2_vectorization", &CompileConfig::half2_vectorization) .def_readwrite("make_cpu_multithreading_loop", &CompileConfig::make_cpu_multithreading_loop) diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index d1d35c8893b5b..9525ffd981c8a 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -82,8 +82,13 @@ class IndependentBlocksJudger : public BasicStmtVisitor { if (is_inside_loop_) return; - if (stmt->dest->is()) { - if (stmt->dest->as() + Stmt *dest = stmt->dest; + if (dest->is()) { + dest = dest->as()->origin; + } + + if (dest->is()) { + if (dest->as() ->base_ptr->as() ->ret_type.ptr_removed() ->as() @@ -92,8 +97,8 @@ class IndependentBlocksJudger : public BasicStmtVisitor { qualified_glb_operations_ = true; } } else { - TI_ASSERT(stmt->dest->is()); - if (stmt->dest->as()->snode->has_adjoint()) { + TI_ASSERT(dest->is()); + if (dest->as()->snode->has_adjoint()) { qualified_glb_operations_ = true; } } @@ -108,15 +113,21 @@ class IndependentBlocksJudger : public BasicStmtVisitor { // another IndependentBlocksJudger if (is_inside_loop_) return; - if ((stmt->src->is() && - stmt->src->as() + + Stmt *src = stmt->src; + if (src->is()) { + src = src->as()->origin; + } + + if ((src->is() && + src->as() ->base_ptr->as() ->ret_type.ptr_removed() ->as() ->elements() .size() > TypeFactory::GRAD_PTR_POS_IN_NDARRAY) || - (stmt->src->is() && - stmt->src->as()->snode->has_adjoint())) { + (src->is() && + src->as()->snode->has_adjoint())) { qualified_glb_operations_ = true; } } @@ -2425,7 +2436,13 @@ class GloablDataAccessRuleChecker : public BasicStmtVisitor { using BasicStmtVisitor::visit; void visit(GlobalLoadStmt *stmt) override { - GlobalPtrStmt *src = stmt->src->as(); + GlobalPtrStmt *src = nullptr; + if (stmt->src->is()) { + src = stmt->src->as(); + } else { + TI_ASSERT(stmt->src->is()); + src = stmt->src->as()->origin->as(); + } auto snode = src->snode; if (!snode->has_adjoint_checkbit()) { return; @@ -2466,12 +2483,24 @@ class GloablDataAccessRuleChecker : public BasicStmtVisitor { } void visit(GlobalStoreStmt *stmt) override { - GlobalPtrStmt *dest = stmt->dest->as(); + GlobalPtrStmt *dest = nullptr; + if (stmt->dest->is()) { + dest = stmt->dest->as(); + } else { + TI_ASSERT(stmt->dest->is()); + dest = stmt->dest->as()->origin->as(); + } visit_gloabl_store_stmt_and_atomic_add(stmt, dest); } void visit(AtomicOpStmt *stmt) override { - GlobalPtrStmt *dest = stmt->dest->as(); + GlobalPtrStmt *dest = nullptr; + if (stmt->dest->is()) { + dest = stmt->dest->as(); + } else { + TI_ASSERT(stmt->dest->is()); + dest = stmt->dest->as()->origin->as(); + } visit_gloabl_store_stmt_and_atomic_add(stmt, dest); } diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 1a3198bbef195..0507b9fd52344 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -86,6 +86,10 @@ void compile_to_offloads(IRNode *ir, irpass::analysis::gather_meshfor_relation_types(ir); } + if (config.force_scalarize_matrix) { + irpass::scalarize(ir, false /*half2_optimization_enabled*/); + } + if (config.debug && autodiff_mode == AutodiffMode::kCheckAutodiffValid) { // Check whether the kernel obeys the autodiff limitation e.g., gloabl data // access rule @@ -136,8 +140,9 @@ void compile_to_offloads(IRNode *ir, // TODO: This pass may be redundant as cfg_optimization() is already called // in full_simplify(). if (config.opt_level > 0 && config.cfg_optimization) { - irpass::cfg_optimization(ir, false, /*autodiff_enabled*/ false, - !config.real_matrix_scalarize); + irpass::cfg_optimization( + ir, false, /*autodiff_enabled*/ false, + !config.real_matrix_scalarize && !config.force_scalarize_matrix); print("Optimized by CFG"); irpass::analysis::verify(ir); } @@ -371,6 +376,10 @@ void compile_function(IRNode *ir, func->set_ir_stage(Function::IRStage::BeforeLowerAccess); } + if (config.force_scalarize_matrix) { + irpass::scalarize(ir, false /*half2_optimization_enabled*/); + } + if (target_stage >= Function::IRStage::OptimizedIR && current_stage < Function::IRStage::OptimizedIR) { irpass::lower_access(ir, config, {{}, true}); diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 550ea8220303b..c11551682e430 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -564,8 +564,9 @@ void full_simplify(IRNode *root, // Don't do this time-consuming optimization pass again if the IR is // not modified. if (config.opt_level > 0 && first_iteration && config.cfg_optimization && - cfg_optimization(root, args.after_lower_access, args.autodiff_enabled, - !config.real_matrix_scalarize)) + cfg_optimization( + root, args.after_lower_access, args.autodiff_enabled, + !config.real_matrix_scalarize && !config.force_scalarize_matrix)) modified = true; print("cfg_optimization"); first_iteration = false; From 4a67b5315440f7442b3d1653cc0a39e4aa6ce251 Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Fri, 12 Apr 2024 16:19:12 +0800 Subject: [PATCH 2/2] bug fix --- taichi/program/compile_config.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/program/compile_config.cpp b/taichi/program/compile_config.cpp index 8c29da201184d..89359c8bb7920 100644 --- a/taichi/program/compile_config.cpp +++ b/taichi/program/compile_config.cpp @@ -43,7 +43,7 @@ CompileConfig::CompileConfig() { make_block_local = true; detect_read_only = true; real_matrix_scalarize = true; - force_scalarize_matrix = true; + force_scalarize_matrix = false; half2_vectorization = false; make_cpu_multithreading_loop = true;