From 417acac77a6c63f706e8e942a410cad7d7b58362 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 24 Sep 2024 18:03:56 +0200 Subject: [PATCH 01/11] Add `DerivativeOriginalFunctionBlock` and `DerivativeVisitor` --- src/language/code_generator.cmake | 1 + src/language/codegen.yaml | 25 +++- src/main.cpp | 8 ++ src/visitors/CMakeLists.txt | 1 + src/visitors/derivative_original_visitor.cpp | 129 +++++++++++++++++++ src/visitors/derivative_original_visitor.hpp | 64 +++++++++ src/visitors/sympy_solver_visitor.cpp | 4 + src/visitors/sympy_solver_visitor.hpp | 2 + 8 files changed, 233 insertions(+), 1 deletion(-) create mode 100644 src/visitors/derivative_original_visitor.cpp create mode 100644 src/visitors/derivative_original_visitor.hpp diff --git a/src/language/code_generator.cmake b/src/language/code_generator.cmake index a3dea0767f..992d5b0cb1 100644 --- a/src/language/code_generator.cmake +++ b/src/language/code_generator.cmake @@ -74,6 +74,7 @@ set(AST_GENERATED_SOURCES ${PROJECT_BINARY_DIR}/src/ast/constructor_block.hpp ${PROJECT_BINARY_DIR}/src/ast/define.hpp ${PROJECT_BINARY_DIR}/src/ast/derivative_block.hpp + ${PROJECT_BINARY_DIR}/src/ast/derivative_original_function_block.hpp ${PROJECT_BINARY_DIR}/src/ast/derivimplicit_callback.hpp ${PROJECT_BINARY_DIR}/src/ast/destructor_block.hpp ${PROJECT_BINARY_DIR}/src/ast/diff_eq_expression.hpp diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index 477df7fa65..ac92afb517 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -87,7 +87,30 @@ type: StatementBlock - finalize_block: brief: "Statement block to be executed after calling linear solver" - type: StatementBlock + type: StatementBlock + - DerivativeOriginalFunctionBlock: + nmodl: "DERIVATIVE_ORIGINAL_FUNCTION " + members: + - name: + brief: "Name of the derivative block" + type: Name + node_name: true + suffix: {value: " "} + - statement_block: + brief: "Block with statements vector" + type: StatementBlock + getter: {override: true} + brief: "Represents the original, unmodified `DERIVATIVE` block in the NMODL" + description: | + The original `DERIVATIVE` block in NMODL is + replaced in-place if the system of ODEs is + solvable analytically. Therefore, this + block's sole purpose is to keep the + original, unsolved block in the AST. This is + primarily useful when we need to solve the + ODE system using implicit methods, for + instance, CVode. + - WrappedExpression: brief: "Wrap any other expression type" members: diff --git a/src/main.cpp b/src/main.cpp index f12bfe35dd..f150753479 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -25,6 +25,7 @@ #include "visitors/after_cvode_to_cnexp_visitor.hpp" #include "visitors/ast_visitor.hpp" #include "visitors/constant_folder_visitor.hpp" +#include "visitors/derivative_original_visitor.hpp" #include "visitors/function_callpath_visitor.hpp" #include "visitors/global_var_visitor.hpp" #include "visitors/implicit_argument_visitor.hpp" @@ -497,6 +498,13 @@ int run_nmodl(int argc, const char* argv[]) { const bool sympy_linear = node_exists(*ast, ast::AstNodeType::LINEAR_BLOCK); const bool sympy_sparse = solver_exists(*ast, "sparse"); + if (neuron_code) { + logger->info("Running derivative visitor"); + DerivativeOriginalVisitor().visit_program(*ast); + SymtabVisitor(update_symtab).visit_program(*ast); + ast_to_nmodl(*ast, filepath("derivative_original")); + } + if (sympy_conductance || sympy_analytic || sympy_sparse || sympy_derivimplicit || sympy_linear) { nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() diff --git a/src/visitors/CMakeLists.txt b/src/visitors/CMakeLists.txt index 262b6a623a..ede77671eb 100644 --- a/src/visitors/CMakeLists.txt +++ b/src/visitors/CMakeLists.txt @@ -11,6 +11,7 @@ add_library( visitor STATIC after_cvode_to_cnexp_visitor.cpp constant_folder_visitor.cpp + derivative_original_visitor.cpp defuse_analyze_visitor.cpp function_callpath_visitor.cpp global_var_visitor.cpp diff --git a/src/visitors/derivative_original_visitor.cpp b/src/visitors/derivative_original_visitor.cpp new file mode 100644 index 0000000000..bd6d135350 --- /dev/null +++ b/src/visitors/derivative_original_visitor.cpp @@ -0,0 +1,129 @@ +/* + * Copyright 2023 Blue Brain Project, EPFL. + * See the top-level LICENSE file for details. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "visitors/derivative_original_visitor.hpp" + +#include "ast/all.hpp" +#include "lexer/token_mapping.hpp" +#include "pybind/pyembed.hpp" +#include "utils/logger.hpp" +#include "visitors/visitor_utils.hpp" +#include +#include + +namespace pywrap = nmodl::pybind_wrappers; + +namespace nmodl { +namespace visitor { + +static int get_index(const ast::IndexedName& node) { + return std::stoi(to_nmodl(node.get_length())); +} + +static auto get_name_map(const ast::Expression& node, const std::string& name) { + std::unordered_map name_map; + // all of the "reserved" symbols + auto reserved_symbols = get_external_functions(); + // all indexed vars + auto indexed_vars = collect_nodes(node, {ast::AstNodeType::INDEXED_NAME}); + for (const auto& var: indexed_vars) { + if (!name_map.count(var->get_node_name()) && var->get_node_name() != name && + std::none_of(reserved_symbols.begin(), reserved_symbols.end(), [&var](const auto item) { + return var->get_node_name() == item; + })) { + logger->debug( + "DerivativeOriginalVisitor :: adding INDEXED_VARIABLE {} to " + "node_map", + var->get_node_name()); + name_map[var->get_node_name()] = get_index( + *std::dynamic_pointer_cast(var)); + } + } + return name_map; +} + + +void DerivativeOriginalVisitor::visit_derivative_block(ast::DerivativeBlock& node) { + node.visit_children(*this); + der_block_function = node.clone(); +} + + +void DerivativeOriginalVisitor::visit_derivative_original_function_block( + ast::DerivativeOriginalFunctionBlock& node) { + derivative_block = true; + node_type = node.get_node_type(); + node.visit_children(*this); + node_type = ast::AstNodeType::NODE; + derivative_block = false; +} + +void DerivativeOriginalVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { + differential_equation = true; + node.visit_children(*this); + differential_equation = false; +} + + +void DerivativeOriginalVisitor::visit_binary_expression(ast::BinaryExpression& node) { + const auto& lhs = node.get_lhs(); + + /// we have to only solve ODEs under original derivative block where lhs is variable + if (!derivative_block || !differential_equation || !lhs->is_var_name()) { + return; + } + + auto name = std::dynamic_pointer_cast(lhs)->get_name(); + + if (name->is_prime_name() || name->is_indexed_name()) { + std::string varname; + if (name->is_prime_name()) { + varname = "D" + name->get_node_name(); + logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}", + name->get_node_name(), + varname, + to_nmodl(node)); + node.set_lhs(std::make_shared(new ast::String(varname))); + if (program_symtab->lookup(varname) == nullptr) { + auto symbol = std::make_shared(varname, ModToken()); + symbol->set_original_name(name->get_node_name()); + program_symtab->insert(symbol); + } + } else { + varname = "D" + stringutils::remove_character(to_nmodl(node.get_lhs()), '\''); + // we discard the RHS here so it can be anything (as long as NMODL considers it valid) + auto statement = fmt::format("{} = {}", varname, varname); + logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}", + to_nmodl(node.get_lhs()), + varname, + to_nmodl(node)); + auto expr_statement = std::dynamic_pointer_cast( + create_statement(statement)); + const auto bin_expr = std::dynamic_pointer_cast( + expr_statement->get_expression()); + node.set_lhs(std::shared_ptr(bin_expr->get_lhs()->clone())); + // TODO add symbol? + } + } +} + +void DerivativeOriginalVisitor::visit_program(ast::Program& node) { + program_symtab = node.get_symbol_table(); + node.visit_children(*this); + if (der_block_function) { + auto der_node = + new ast::DerivativeOriginalFunctionBlock(der_block_function->get_name(), + der_block_function->get_statement_block()); + node.emplace_back_node(der_node); + } + + // re-visit the AST since we now inserted the DERIVATIVE_ORIGINAL block + node.visit_children(*this); +} + +} // namespace visitor +} // namespace nmodl diff --git a/src/visitors/derivative_original_visitor.hpp b/src/visitors/derivative_original_visitor.hpp new file mode 100644 index 0000000000..d483ab845b --- /dev/null +++ b/src/visitors/derivative_original_visitor.hpp @@ -0,0 +1,64 @@ +/* + * Copyright 2023 Blue Brain Project, EPFL. + * See the top-level LICENSE file for details. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +/** + * \file + * \brief \copybrief nmodl::visitor::DerivativeOriginalVisitor + */ + +#include "symtab/decl.hpp" +#include "visitors/ast_visitor.hpp" +#include + +namespace nmodl { +namespace visitor { + +/** + * \addtogroup visitor_classes + * \{ + */ + +/** + * \class DerivativeOriginalVisitor + * \brief Make a copy of the `DERIVATIVE` block (if it exists), and insert back as + * `DERIVATIVE_ORIGINAL_FUNCTION` block. + * + * If \ref SympySolverVisitor runs successfully, it replaces the original + * solution. This block is inserted before that to prevent losing access to + * information about the block. + */ +class DerivativeOriginalVisitor: public AstVisitor { + private: + /// The copy of the derivative block we are solving + ast::DerivativeBlock* der_block_function = nullptr; + + /// true while visiting differential equation + bool differential_equation = false; + + /// global symbol table + symtab::SymbolTable* program_symtab = nullptr; + + /// visiting derivative block + bool derivative_block = false; + + ast::AstNodeType node_type = ast::AstNodeType::NODE; + + public: + void visit_derivative_block(ast::DerivativeBlock& node) override; + void visit_program(ast::Program& node) override; + void visit_derivative_original_function_block( + ast::DerivativeOriginalFunctionBlock& node) override; + void visit_diff_eq_expression(ast::DiffEqExpression& node) override; + void visit_binary_expression(ast::BinaryExpression& node) override; +}; + +/** \} */ // end of visitor_classes + +} // namespace visitor +} // namespace nmodl diff --git a/src/visitors/sympy_solver_visitor.cpp b/src/visitors/sympy_solver_visitor.cpp index f2d6260c21..e7b955a5c0 100644 --- a/src/visitors/sympy_solver_visitor.cpp +++ b/src/visitors/sympy_solver_visitor.cpp @@ -399,6 +399,10 @@ void SympySolverVisitor::visit_var_name(ast::VarName& node) { } } +// Skip visiting DERIVATIVE_ORIGINAL block +void SympySolverVisitor::visit_derivative_original_function_block( + ast::DerivativeOriginalFunctionBlock& node) {} + void SympySolverVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { const auto& lhs = node.get_expression()->get_lhs(); diff --git a/src/visitors/sympy_solver_visitor.hpp b/src/visitors/sympy_solver_visitor.hpp index ecb326ab63..627451d4b7 100644 --- a/src/visitors/sympy_solver_visitor.hpp +++ b/src/visitors/sympy_solver_visitor.hpp @@ -185,6 +185,8 @@ class SympySolverVisitor: public AstVisitor { void visit_expression_statement(ast::ExpressionStatement& node) override; void visit_statement_block(ast::StatementBlock& node) override; void visit_program(ast::Program& node) override; + void visit_derivative_original_function_block( + ast::DerivativeOriginalFunctionBlock& node) override; }; /** @} */ // end of visitor_classes From d33a594575291e4b7cc1fcb10f76501c17ec6143 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 24 Sep 2024 18:18:33 +0200 Subject: [PATCH 02/11] Remove unused functions --- src/visitors/derivative_original_visitor.cpp | 26 -------------------- 1 file changed, 26 deletions(-) diff --git a/src/visitors/derivative_original_visitor.cpp b/src/visitors/derivative_original_visitor.cpp index bd6d135350..2e7b6942a2 100644 --- a/src/visitors/derivative_original_visitor.cpp +++ b/src/visitors/derivative_original_visitor.cpp @@ -20,32 +20,6 @@ namespace pywrap = nmodl::pybind_wrappers; namespace nmodl { namespace visitor { -static int get_index(const ast::IndexedName& node) { - return std::stoi(to_nmodl(node.get_length())); -} - -static auto get_name_map(const ast::Expression& node, const std::string& name) { - std::unordered_map name_map; - // all of the "reserved" symbols - auto reserved_symbols = get_external_functions(); - // all indexed vars - auto indexed_vars = collect_nodes(node, {ast::AstNodeType::INDEXED_NAME}); - for (const auto& var: indexed_vars) { - if (!name_map.count(var->get_node_name()) && var->get_node_name() != name && - std::none_of(reserved_symbols.begin(), reserved_symbols.end(), [&var](const auto item) { - return var->get_node_name() == item; - })) { - logger->debug( - "DerivativeOriginalVisitor :: adding INDEXED_VARIABLE {} to " - "node_map", - var->get_node_name()); - name_map[var->get_node_name()] = get_index( - *std::dynamic_pointer_cast(var)); - } - } - return name_map; -} - void DerivativeOriginalVisitor::visit_derivative_block(ast::DerivativeBlock& node) { node.visit_children(*this); From b9f08d05225b37c9cd288c3471112cda51861bb4 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 09:27:13 +0200 Subject: [PATCH 03/11] Add test for DerivativeOriginalVisitor --- test/unit/CMakeLists.txt | 1 + test/unit/visitor/derivative_original.cpp | 55 +++++++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 test/unit/visitor/derivative_original.cpp diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 44d57fe91f..9ed95d8aff 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -45,6 +45,7 @@ add_executable( visitor/kinetic_block.cpp visitor/localize.cpp visitor/localrename.cpp + visitor/derivative_original.cpp visitor/local_to_assigned.cpp visitor/lookup.cpp visitor/loop_unroll.cpp diff --git a/test/unit/visitor/derivative_original.cpp b/test/unit/visitor/derivative_original.cpp new file mode 100644 index 0000000000..d2f5e17cf2 --- /dev/null +++ b/test/unit/visitor/derivative_original.cpp @@ -0,0 +1,55 @@ +#include + +#include "ast/program.hpp" +#include "parser/nmodl_driver.hpp" +#include "test/unit/utils/test_utils.hpp" +#include "visitors/checkparent_visitor.hpp" +#include "visitors/nmodl_visitor.hpp" +#include "visitors/symtab_visitor.hpp" +#include "visitors/derivative_original_visitor.hpp" +#include "visitors/visitor_utils.hpp" + +using namespace nmodl; +using namespace visitor; +using namespace test; +using namespace test_utils; + +using nmodl::parser::NmodlDriver; + + +auto run_derivative_original_visitor(const std::string& text) { + NmodlDriver driver; + const auto& ast = driver.parse_string(text); + SymtabVisitor().visit_program(*ast); + DerivativeOriginalVisitor().visit_program(*ast); + + return ast; +} + + +TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][derivative_original]") { + GIVEN("DERIVATIVE block") { + std::string nmodl_text = R"( + NEURON { + SUFFIX example + } + + STATE {x z[2]} + + DERIVATIVE equation { + x' = -x + z'[0] = x + z'[1] = x + z[0] + } +)"; + auto ast = run_derivative_original_visitor(nmodl_text); + THEN("DERIVATIVE_ORIGINAL_FUNCTION block is added") { + auto block = collect_nodes(*ast, {ast::AstNodeType::DERIVATIVE_ORIGINAL_FUNCTION_BLOCK}); + REQUIRE(!block.empty()); + THEN("No primed variables exist in the DERIVATIVE_ORIGINAL_FUNCTION block") { + auto primed_vars = collect_nodes(*block[0], {ast::AstNodeType::PRIME_NAME}); + REQUIRE(primed_vars.empty()); + } + } + } +} From c5dc45e9a3a63e2f4e63996d4254426f6e2f6fa6 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 09:28:40 +0200 Subject: [PATCH 04/11] Fmt --- test/unit/visitor/derivative_original.cpp | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/test/unit/visitor/derivative_original.cpp b/test/unit/visitor/derivative_original.cpp index d2f5e17cf2..58c0bd9c4f 100644 --- a/test/unit/visitor/derivative_original.cpp +++ b/test/unit/visitor/derivative_original.cpp @@ -4,9 +4,9 @@ #include "parser/nmodl_driver.hpp" #include "test/unit/utils/test_utils.hpp" #include "visitors/checkparent_visitor.hpp" +#include "visitors/derivative_original_visitor.hpp" #include "visitors/nmodl_visitor.hpp" #include "visitors/symtab_visitor.hpp" -#include "visitors/derivative_original_visitor.hpp" #include "visitors/visitor_utils.hpp" using namespace nmodl; @@ -28,8 +28,8 @@ auto run_derivative_original_visitor(const std::string& text) { TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][derivative_original]") { - GIVEN("DERIVATIVE block") { - std::string nmodl_text = R"( + GIVEN("DERIVATIVE block") { + std::string nmodl_text = R"( NEURON { SUFFIX example } @@ -42,14 +42,15 @@ TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][derivative z'[1] = x + z[0] } )"; - auto ast = run_derivative_original_visitor(nmodl_text); - THEN("DERIVATIVE_ORIGINAL_FUNCTION block is added") { - auto block = collect_nodes(*ast, {ast::AstNodeType::DERIVATIVE_ORIGINAL_FUNCTION_BLOCK}); - REQUIRE(!block.empty()); - THEN("No primed variables exist in the DERIVATIVE_ORIGINAL_FUNCTION block") { - auto primed_vars = collect_nodes(*block[0], {ast::AstNodeType::PRIME_NAME}); - REQUIRE(primed_vars.empty()); - } + auto ast = run_derivative_original_visitor(nmodl_text); + THEN("DERIVATIVE_ORIGINAL_FUNCTION block is added") { + auto block = collect_nodes(*ast, + {ast::AstNodeType::DERIVATIVE_ORIGINAL_FUNCTION_BLOCK}); + REQUIRE(!block.empty()); + THEN("No primed variables exist in the DERIVATIVE_ORIGINAL_FUNCTION block") { + auto primed_vars = collect_nodes(*block[0], {ast::AstNodeType::PRIME_NAME}); + REQUIRE(primed_vars.empty()); } } + } } From 1dadd7a21b43b97b9caec984e1fb48ea2cf001db Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 10:44:01 +0200 Subject: [PATCH 05/11] Fix leak --- src/visitors/derivative_original_visitor.cpp | 2 +- src/visitors/derivative_original_visitor.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/visitors/derivative_original_visitor.cpp b/src/visitors/derivative_original_visitor.cpp index 2e7b6942a2..9377641851 100644 --- a/src/visitors/derivative_original_visitor.cpp +++ b/src/visitors/derivative_original_visitor.cpp @@ -23,7 +23,7 @@ namespace visitor { void DerivativeOriginalVisitor::visit_derivative_block(ast::DerivativeBlock& node) { node.visit_children(*this); - der_block_function = node.clone(); + der_block_function = std::shared_ptr(node.clone()); } diff --git a/src/visitors/derivative_original_visitor.hpp b/src/visitors/derivative_original_visitor.hpp index d483ab845b..7178390ca8 100644 --- a/src/visitors/derivative_original_visitor.hpp +++ b/src/visitors/derivative_original_visitor.hpp @@ -36,7 +36,7 @@ namespace visitor { class DerivativeOriginalVisitor: public AstVisitor { private: /// The copy of the derivative block we are solving - ast::DerivativeBlock* der_block_function = nullptr; + std::shared_ptr der_block_function = nullptr; /// true while visiting differential equation bool differential_equation = false; From 1125fdf43c886b9250faabd7dbb55241f610ecac Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 11:32:07 +0200 Subject: [PATCH 06/11] Remove unused stuff `DERIVATIVE` blocks can't have array variables in NOCMODL by default, so let's go with that. --- src/visitors/derivative_original_visitor.cpp | 41 ++++++-------------- src/visitors/derivative_original_visitor.hpp | 2 - test/unit/visitor/derivative_original.cpp | 7 ++-- 3 files changed, 14 insertions(+), 36 deletions(-) diff --git a/src/visitors/derivative_original_visitor.cpp b/src/visitors/derivative_original_visitor.cpp index 9377641851..43b3eb6df8 100644 --- a/src/visitors/derivative_original_visitor.cpp +++ b/src/visitors/derivative_original_visitor.cpp @@ -30,9 +30,7 @@ void DerivativeOriginalVisitor::visit_derivative_block(ast::DerivativeBlock& nod void DerivativeOriginalVisitor::visit_derivative_original_function_block( ast::DerivativeOriginalFunctionBlock& node) { derivative_block = true; - node_type = node.get_node_type(); node.visit_children(*this); - node_type = ast::AstNodeType::NODE; derivative_block = false; } @@ -53,34 +51,17 @@ void DerivativeOriginalVisitor::visit_binary_expression(ast::BinaryExpression& n auto name = std::dynamic_pointer_cast(lhs)->get_name(); - if (name->is_prime_name() || name->is_indexed_name()) { - std::string varname; - if (name->is_prime_name()) { - varname = "D" + name->get_node_name(); - logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}", - name->get_node_name(), - varname, - to_nmodl(node)); - node.set_lhs(std::make_shared(new ast::String(varname))); - if (program_symtab->lookup(varname) == nullptr) { - auto symbol = std::make_shared(varname, ModToken()); - symbol->set_original_name(name->get_node_name()); - program_symtab->insert(symbol); - } - } else { - varname = "D" + stringutils::remove_character(to_nmodl(node.get_lhs()), '\''); - // we discard the RHS here so it can be anything (as long as NMODL considers it valid) - auto statement = fmt::format("{} = {}", varname, varname); - logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}", - to_nmodl(node.get_lhs()), - varname, - to_nmodl(node)); - auto expr_statement = std::dynamic_pointer_cast( - create_statement(statement)); - const auto bin_expr = std::dynamic_pointer_cast( - expr_statement->get_expression()); - node.set_lhs(std::shared_ptr(bin_expr->get_lhs()->clone())); - // TODO add symbol? + if (name->is_prime_name()) { + auto varname = "D" + name->get_node_name(); + logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}", + name->get_node_name(), + varname, + to_nmodl(node)); + node.set_lhs(std::make_shared(new ast::String(varname))); + if (program_symtab->lookup(varname) == nullptr) { + auto symbol = std::make_shared(varname, ModToken()); + symbol->set_original_name(name->get_node_name()); + program_symtab->insert(symbol); } } } diff --git a/src/visitors/derivative_original_visitor.hpp b/src/visitors/derivative_original_visitor.hpp index 7178390ca8..2fb3b26297 100644 --- a/src/visitors/derivative_original_visitor.hpp +++ b/src/visitors/derivative_original_visitor.hpp @@ -47,8 +47,6 @@ class DerivativeOriginalVisitor: public AstVisitor { /// visiting derivative block bool derivative_block = false; - ast::AstNodeType node_type = ast::AstNodeType::NODE; - public: void visit_derivative_block(ast::DerivativeBlock& node) override; void visit_program(ast::Program& node) override; diff --git a/test/unit/visitor/derivative_original.cpp b/test/unit/visitor/derivative_original.cpp index 58c0bd9c4f..4533de36d5 100644 --- a/test/unit/visitor/derivative_original.cpp +++ b/test/unit/visitor/derivative_original.cpp @@ -34,12 +34,11 @@ TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][derivative SUFFIX example } - STATE {x z[2]} + STATE {x z} DERIVATIVE equation { - x' = -x - z'[0] = x - z'[1] = x + z[0] + x' = -x + z * z + z' = z * x } )"; auto ast = run_derivative_original_visitor(nmodl_text); From 0267fbdf9c1f5a2fca08da574bd40a5d7949305c Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 11:45:09 +0200 Subject: [PATCH 07/11] Update block description --- src/language/codegen.yaml | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index ac92afb517..e31cb4aca4 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -100,16 +100,15 @@ brief: "Block with statements vector" type: StatementBlock getter: {override: true} - brief: "Represents the original, unmodified `DERIVATIVE` block in the NMODL" + brief: "Represents a copy of the `DERIVATIVE` block in NMODL with prime vars replaced by D vars" description: | The original `DERIVATIVE` block in NMODL is replaced in-place if the system of ODEs is solvable analytically. Therefore, this - block's sole purpose is to keep the - original, unsolved block in the AST. This is - primarily useful when we need to solve the - ODE system using implicit methods, for - instance, CVode. + block's sole purpose is to keep the unsolved + block in the AST. This is primarily useful + when we need to solve the ODE system using + implicit methods, for instance, CVode. - WrappedExpression: brief: "Wrap any other expression type" From efdf0c0c6f12e3e9e318eddfb0556f4890f661e8 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 15:30:40 +0200 Subject: [PATCH 08/11] Add `DerivativeOriginalJacobianBlock` --- python/nmodl/ode.py | 25 +++++-- src/language/code_generator.cmake | 1 + src/language/codegen.yaml | 22 ++++++ src/main.cpp | 17 ++--- src/pybind/wrapper.cpp | 40 ++++++++++- src/pybind/wrapper.hpp | 7 ++ src/visitors/derivative_original_visitor.cpp | 70 ++++++++++++++++++++ src/visitors/derivative_original_visitor.hpp | 5 ++ src/visitors/sympy_solver_visitor.cpp | 2 + src/visitors/sympy_solver_visitor.hpp | 2 + 10 files changed, 177 insertions(+), 14 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index 3fe769e596..2eab38e873 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -608,7 +608,12 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): vars = set(vars) vars.discard(dependent_var) # declare all other supplied variables - sympy_vars = {var: sp.symbols(var, real=True) for var in vars} + sympy_vars = { + var if isinstance(var, str) else str(var): ( + sp.symbols(var, real=True) if isinstance(var, str) else var + ) + for var in vars + } sympy_vars[dependent_var] = x # parse string into SymPy equation @@ -643,15 +648,27 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): # differentiate w.r.t. x diff = expr.diff(x).simplify() + # could be something generic like f'(x), in which case we use finite differences + if needs_finite_differences(diff): + diff = ( + transform_expression(diff, discretize_derivative) + .subs({finite_difference_step_variable(x): 1e-3}) + .evalf() + ) + + # the codegen method does not like undefined function calls, so we extract + # them here + custom_fcts = {str(f.func): str(f.func) for f in diff.atoms(sp.Function)} + # try to simplify expression in terms of existing variables # ignore any exceptions here, since we already have a valid solution # so if this further simplification step fails the error is not fatal try: # if expression is equal to one of the supplied vars, replace with this var # can do a simple string comparison here since a var cannot be further simplified - diff_as_string = sp.ccode(diff) + diff_as_string = sp.ccode(diff, user_functions=custom_fcts) for v in sympy_vars: - if diff_as_string == sp.ccode(sympy_vars[v]): + if diff_as_string == sp.ccode(sympy_vars[v], user_functions=custom_fcts): diff = sympy_vars[v] # or if equal to rhs of one of the supplied equations, replace with lhs @@ -672,4 +689,4 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): pass # return result as C code in NEURON format - return sp.ccode(diff.evalf()) + return sp.ccode(diff.evalf(), user_functions=custom_fcts) diff --git a/src/language/code_generator.cmake b/src/language/code_generator.cmake index 992d5b0cb1..7e6f827506 100644 --- a/src/language/code_generator.cmake +++ b/src/language/code_generator.cmake @@ -75,6 +75,7 @@ set(AST_GENERATED_SOURCES ${PROJECT_BINARY_DIR}/src/ast/define.hpp ${PROJECT_BINARY_DIR}/src/ast/derivative_block.hpp ${PROJECT_BINARY_DIR}/src/ast/derivative_original_function_block.hpp + ${PROJECT_BINARY_DIR}/src/ast/derivative_original_jacobian_block.hpp ${PROJECT_BINARY_DIR}/src/ast/derivimplicit_callback.hpp ${PROJECT_BINARY_DIR}/src/ast/destructor_block.hpp ${PROJECT_BINARY_DIR}/src/ast/diff_eq_expression.hpp diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index e31cb4aca4..33b77e4eb0 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -109,6 +109,28 @@ block in the AST. This is primarily useful when we need to solve the ODE system using implicit methods, for instance, CVode. + - DerivativeOriginalJacobianBlock: + nmodl: "DERIVATIVE_ORIGINAL_JACOBIAN " + members: + - name: + brief: "Name of the derivative block" + type: Name + node_name: true + suffix: {value: " "} + - statement_block: + brief: "Block with statements vector" + type: StatementBlock + getter: {override: true} + brief: "Represents the original, unmodified `DERIVATIVE` block in the NMODL" + description: | + The original `DERIVATIVE` block in NMODL is + replaced in-place if the system of ODEs is + solvable analytically. Therefore, this + block's sole purpose is to keep the + original, unsolved block in the AST. This is + primarily useful when we need to solve the + ODE system using implicit methods, for + instance, CVode. - WrappedExpression: brief: "Wrap any other expression type" diff --git a/src/main.cpp b/src/main.cpp index f150753479..7513997697 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -498,18 +498,19 @@ int run_nmodl(int argc, const char* argv[]) { const bool sympy_linear = node_exists(*ast, ast::AstNodeType::LINEAR_BLOCK); const bool sympy_sparse = solver_exists(*ast, "sparse"); - if (neuron_code) { - logger->info("Running derivative visitor"); - DerivativeOriginalVisitor().visit_program(*ast); - SymtabVisitor(update_symtab).visit_program(*ast); - ast_to_nmodl(*ast, filepath("derivative_original")); - } - if (sympy_conductance || sympy_analytic || sympy_sparse || sympy_derivimplicit || - sympy_linear) { + sympy_linear || neuron_code) { nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() .api() .initialize_interpreter(); + + if (neuron_code) { + logger->info("Running derivative visitor"); + DerivativeOriginalVisitor().visit_program(*ast); + SymtabVisitor(update_symtab).visit_program(*ast); + ast_to_nmodl(*ast, filepath("derivative_original")); + } + if (sympy_conductance) { logger->info("Running sympy conductance visitor"); SympyConductanceVisitor().visit_program(*ast); diff --git a/src/pybind/wrapper.cpp b/src/pybind/wrapper.cpp index 32c390736c..ae9d414976 100644 --- a/src/pybind/wrapper.cpp +++ b/src/pybind/wrapper.cpp @@ -9,7 +9,7 @@ #include "codegen/codegen_naming.hpp" #include "pybind/pyembed.hpp" - +#include #include #include @@ -186,6 +186,41 @@ except Exception as e: return {std::move(solution), std::move(exception_message)}; } +/// \brief A blunt instrument that differentiates expression w.r.t. variable +/// \return The tuple (solution, exception) +std::tuple call_diff2c( + const std::string& expression, + const std::string& variable, + const std::unordered_map& indexed_vars) { + std::string statements; + // only indexed variables require special treatment + for (const auto& [var, prop]: indexed_vars) { + statements += fmt::format("_allvars.append(sp.IndexedBase('{}', shape=[1]))\n", var); + } + auto locals = py::dict("expression"_a = expression, "variable"_a = variable); + std::string script = fmt::format(R"( +_allvars = [] +{} +exception_message = "" +try: + solution = differentiate2c(expression, + variable, + _allvars, + ) +except Exception as e: + # if we fail, fail silently and return empty string + solution = "" + exception_message = str(e) +)", + statements); + + py::exec(nmodl::pybind_wrappers::ode_py + script, locals); + + auto solution = locals["solution"].cast(); + auto exception_message = locals["exception_message"].cast(); + + return {std::move(solution), std::move(exception_message)}; +} void initialize_interpreter_func() { pybind11::initialize_interpreter(true); @@ -203,7 +238,8 @@ NMODL_EXPORT pybind_wrap_api nmodl_init_pybind_wrapper_api() noexcept { &call_solve_nonlinear_system, &call_solve_linear_system, &call_diffeq_solver, - &call_analytic_diff}; + &call_analytic_diff, + &call_diff2c}; } } diff --git a/src/pybind/wrapper.hpp b/src/pybind/wrapper.hpp index 725f9f8113..b4ec0a2dff 100644 --- a/src/pybind/wrapper.hpp +++ b/src/pybind/wrapper.hpp @@ -9,6 +9,7 @@ #include #include +#include #include namespace nmodl { @@ -44,6 +45,11 @@ std::tuple call_analytic_diff( const std::vector& expressions, const std::set& used_names_in_block); +std::tuple call_diff2c( + const std::string& expression, + const std::string& variable, + const std::unordered_map& indexed_vars = {}); + struct pybind_wrap_api { decltype(&initialize_interpreter_func) initialize_interpreter; decltype(&finalize_interpreter_func) finalize_interpreter; @@ -51,6 +57,7 @@ struct pybind_wrap_api { decltype(&call_solve_linear_system) solve_linear_system; decltype(&call_diffeq_solver) diffeq_solver; decltype(&call_analytic_diff) analytic_diff; + decltype(&call_diff2c) diff2c; }; #ifdef _WIN32 diff --git a/src/visitors/derivative_original_visitor.cpp b/src/visitors/derivative_original_visitor.cpp index 43b3eb6df8..1625cdc52e 100644 --- a/src/visitors/derivative_original_visitor.cpp +++ b/src/visitors/derivative_original_visitor.cpp @@ -20,17 +20,55 @@ namespace pywrap = nmodl::pybind_wrappers; namespace nmodl { namespace visitor { +static int get_index(const ast::IndexedName& node) { + return std::stoi(to_nmodl(node.get_length())); +} + +static auto get_name_map(const ast::Expression& node, const std::string& name) { + std::unordered_map name_map; + // all of the "reserved" symbols + auto reserved_symbols = get_external_functions(); + // all indexed vars + auto indexed_vars = collect_nodes(node, {ast::AstNodeType::INDEXED_NAME}); + for (const auto& var: indexed_vars) { + if (!name_map.count(var->get_node_name()) && var->get_node_name() != name && + std::none_of(reserved_symbols.begin(), reserved_symbols.end(), [&var](const auto item) { + return var->get_node_name() == item; + })) { + logger->debug( + "DerivativeOriginalVisitor :: adding INDEXED_VARIABLE {} to " + "node_map", + var->get_node_name()); + name_map[var->get_node_name()] = get_index( + *std::dynamic_pointer_cast(var)); + } + } + return name_map; +} + void DerivativeOriginalVisitor::visit_derivative_block(ast::DerivativeBlock& node) { node.visit_children(*this); der_block_function = std::shared_ptr(node.clone()); + der_block_jacobian = std::shared_ptr(node.clone()); } void DerivativeOriginalVisitor::visit_derivative_original_function_block( ast::DerivativeOriginalFunctionBlock& node) { derivative_block = true; + node_type = node.get_node_type(); node.visit_children(*this); + node_type = ast::AstNodeType::NODE; + derivative_block = false; +} + +void DerivativeOriginalVisitor::visit_derivative_original_jacobian_block( + ast::DerivativeOriginalJacobianBlock& node) { + derivative_block = true; + node_type = node.get_node_type(); + node.visit_children(*this); + node_type = ast::AstNodeType::NODE; derivative_block = false; } @@ -63,6 +101,32 @@ void DerivativeOriginalVisitor::visit_binary_expression(ast::BinaryExpression& n symbol->set_original_name(name->get_node_name()); program_symtab->insert(symbol); } + if (node_type == ast::AstNodeType::DERIVATIVE_ORIGINAL_JACOBIAN_BLOCK) { + logger->debug( + "DerivativeOriginalVisitor :: visiting expr {} in DERIVATIVE_ORIGINAL_JACOBIAN", + to_nmodl(node)); + auto rhs = node.get_rhs(); + // map of all indexed symbols (need special treatment in SymPy) + auto name_map = get_name_map(*rhs, name->get_node_name()); + auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c; + auto [jacobian, + exception_message] = diff2c(to_nmodl(*rhs), name->get_node_name(), name_map); + if (!exception_message.empty()) { + logger->warn("DerivativeOriginalVisitor :: python exception: {}", + exception_message); + } + // NOTE: LHS can be anything here, the equality is to keep `create_statement` from + // complaining, we discard the LHS later + auto statement = fmt::format("{} = {} / (1 - dt * ({}))", varname, varname, jacobian); + logger->debug("DerivativeOriginalVisitor :: replacing statement {} with {}", + to_nmodl(node), + statement); + auto expr_statement = std::dynamic_pointer_cast( + create_statement(statement)); + const auto bin_expr = std::dynamic_pointer_cast( + expr_statement->get_expression()); + node.set_rhs(std::shared_ptr(bin_expr->get_rhs()->clone())); + } } } @@ -75,6 +139,12 @@ void DerivativeOriginalVisitor::visit_program(ast::Program& node) { der_block_function->get_statement_block()); node.emplace_back_node(der_node); } + if (der_block_jacobian) { + auto der_node = + new ast::DerivativeOriginalJacobianBlock(der_block_jacobian->get_name(), + der_block_jacobian->get_statement_block()); + node.emplace_back_node(der_node); + } // re-visit the AST since we now inserted the DERIVATIVE_ORIGINAL block node.visit_children(*this); diff --git a/src/visitors/derivative_original_visitor.hpp b/src/visitors/derivative_original_visitor.hpp index 2fb3b26297..e268e6557f 100644 --- a/src/visitors/derivative_original_visitor.hpp +++ b/src/visitors/derivative_original_visitor.hpp @@ -37,6 +37,7 @@ class DerivativeOriginalVisitor: public AstVisitor { private: /// The copy of the derivative block we are solving std::shared_ptr der_block_function = nullptr; + std::shared_ptr der_block_jacobian = nullptr; /// true while visiting differential equation bool differential_equation = false; @@ -47,11 +48,15 @@ class DerivativeOriginalVisitor: public AstVisitor { /// visiting derivative block bool derivative_block = false; + ast::AstNodeType node_type = ast::AstNodeType::NODE; + public: void visit_derivative_block(ast::DerivativeBlock& node) override; void visit_program(ast::Program& node) override; void visit_derivative_original_function_block( ast::DerivativeOriginalFunctionBlock& node) override; + void visit_derivative_original_jacobian_block( + ast::DerivativeOriginalJacobianBlock& node) override; void visit_diff_eq_expression(ast::DiffEqExpression& node) override; void visit_binary_expression(ast::BinaryExpression& node) override; }; diff --git a/src/visitors/sympy_solver_visitor.cpp b/src/visitors/sympy_solver_visitor.cpp index e7b955a5c0..fc8bba8a8f 100644 --- a/src/visitors/sympy_solver_visitor.cpp +++ b/src/visitors/sympy_solver_visitor.cpp @@ -402,6 +402,8 @@ void SympySolverVisitor::visit_var_name(ast::VarName& node) { // Skip visiting DERIVATIVE_ORIGINAL block void SympySolverVisitor::visit_derivative_original_function_block( ast::DerivativeOriginalFunctionBlock& node) {} +void SympySolverVisitor::visit_derivative_original_jacobian_block( + ast::DerivativeOriginalJacobianBlock& node) {} void SympySolverVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { const auto& lhs = node.get_expression()->get_lhs(); diff --git a/src/visitors/sympy_solver_visitor.hpp b/src/visitors/sympy_solver_visitor.hpp index 627451d4b7..7fa46492e3 100644 --- a/src/visitors/sympy_solver_visitor.hpp +++ b/src/visitors/sympy_solver_visitor.hpp @@ -187,6 +187,8 @@ class SympySolverVisitor: public AstVisitor { void visit_program(ast::Program& node) override; void visit_derivative_original_function_block( ast::DerivativeOriginalFunctionBlock& node) override; + void visit_derivative_original_jacobian_block( + ast::DerivativeOriginalJacobianBlock& node) override; }; /** @} */ // end of visitor_classes From 5fb7bfa87e5cb3de5734c40a42a6eaffec94814c Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 15:33:42 +0200 Subject: [PATCH 09/11] Add test for Jacobian --- test/unit/visitor/derivative_original.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/unit/visitor/derivative_original.cpp b/test/unit/visitor/derivative_original.cpp index 4533de36d5..774b3762d5 100644 --- a/test/unit/visitor/derivative_original.cpp +++ b/test/unit/visitor/derivative_original.cpp @@ -51,5 +51,14 @@ TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][derivative REQUIRE(primed_vars.empty()); } } + THEN("DERIVATIVE_ORIGINAL_JACOBIAN block is added") { + auto block = collect_nodes(*ast, + {ast::AstNodeType::DERIVATIVE_ORIGINAL_JACOBIAN_BLOCK}); + REQUIRE(!block.empty()); + THEN("No primed variables exist in the DERIVATIVE_ORIGINAL_JACOBIAN block") { + auto primed_vars = collect_nodes(*block[0], {ast::AstNodeType::PRIME_NAME}); + REQUIRE(primed_vars.empty()); + } + } } } From 507056165b0196ebae5ca65f0da80f86ab0791c4 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 15:36:09 +0200 Subject: [PATCH 10/11] Fix wording of block --- src/language/codegen.yaml | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index 33b77e4eb0..24d75a1230 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -121,16 +121,15 @@ brief: "Block with statements vector" type: StatementBlock getter: {override: true} - brief: "Represents the original, unmodified `DERIVATIVE` block in the NMODL" + brief: "Represents a copy of the `DERIVATIVE` block in NMODL with prime vars replaced by D vars and RHS by D vars / (1 - dt * J(vars))" description: | The original `DERIVATIVE` block in NMODL is replaced in-place if the system of ODEs is solvable analytically. Therefore, this - block's sole purpose is to keep the - original, unsolved block in the AST. This is - primarily useful when we need to solve the - ODE system using implicit methods, for - instance, CVode. + block's sole purpose is to keep the unsolved + block in the AST. This is primarily useful + when we need to solve the ODE system using + implicit methods, for instance, CVode. - WrappedExpression: brief: "Wrap any other expression type" From cbc3017803db3f7f5edd809024844150d1bb2661 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 15:41:12 +0200 Subject: [PATCH 11/11] Can I run the CI now pls? Thanks