Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DerivativeOriginalJacobianBlock #1471

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions python/nmodl/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,12 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None):
vars = set(vars)
vars.discard(dependent_var)
# declare all other supplied variables
sympy_vars = {var: sp.symbols(var, real=True) for var in vars}
sympy_vars = {
var if isinstance(var, str) else str(var): (
sp.symbols(var, real=True) if isinstance(var, str) else var
)
for var in vars
}
sympy_vars[dependent_var] = x

# parse string into SymPy equation
Expand Down Expand Up @@ -643,15 +648,27 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None):
# differentiate w.r.t. x
diff = expr.diff(x).simplify()

# could be something generic like f'(x), in which case we use finite differences
if needs_finite_differences(diff):
diff = (
transform_expression(diff, discretize_derivative)
.subs({finite_difference_step_variable(x): 1e-3})
.evalf()
)

# the codegen method does not like undefined function calls, so we extract
# them here
custom_fcts = {str(f.func): str(f.func) for f in diff.atoms(sp.Function)}

# try to simplify expression in terms of existing variables
# ignore any exceptions here, since we already have a valid solution
# so if this further simplification step fails the error is not fatal
try:
# if expression is equal to one of the supplied vars, replace with this var
# can do a simple string comparison here since a var cannot be further simplified
diff_as_string = sp.ccode(diff)
diff_as_string = sp.ccode(diff, user_functions=custom_fcts)
for v in sympy_vars:
if diff_as_string == sp.ccode(sympy_vars[v]):
if diff_as_string == sp.ccode(sympy_vars[v], user_functions=custom_fcts):
diff = sympy_vars[v]

# or if equal to rhs of one of the supplied equations, replace with lhs
Expand All @@ -672,4 +689,4 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None):
pass

# return result as C code in NEURON format
return sp.ccode(diff.evalf())
return sp.ccode(diff.evalf(), user_functions=custom_fcts)
2 changes: 2 additions & 0 deletions src/language/code_generator.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ set(AST_GENERATED_SOURCES
${PROJECT_BINARY_DIR}/src/ast/constructor_block.hpp
${PROJECT_BINARY_DIR}/src/ast/define.hpp
${PROJECT_BINARY_DIR}/src/ast/derivative_block.hpp
${PROJECT_BINARY_DIR}/src/ast/derivative_original_function_block.hpp
${PROJECT_BINARY_DIR}/src/ast/derivative_original_jacobian_block.hpp
${PROJECT_BINARY_DIR}/src/ast/derivimplicit_callback.hpp
${PROJECT_BINARY_DIR}/src/ast/destructor_block.hpp
${PROJECT_BINARY_DIR}/src/ast/diff_eq_expression.hpp
Expand Down
45 changes: 44 additions & 1 deletion src/language/codegen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,50 @@
type: StatementBlock
- finalize_block:
brief: "Statement block to be executed after calling linear solver"
type: StatementBlock
type: StatementBlock
- DerivativeOriginalFunctionBlock:
nmodl: "DERIVATIVE_ORIGINAL_FUNCTION "
members:
- name:
brief: "Name of the derivative block"
type: Name
node_name: true
suffix: {value: " "}
- statement_block:
brief: "Block with statements vector"
type: StatementBlock
getter: {override: true}
brief: "Represents a copy of the `DERIVATIVE` block in NMODL with prime vars replaced by D vars"
description: |
The original `DERIVATIVE` block in NMODL is
replaced in-place if the system of ODEs is
solvable analytically. Therefore, this
block's sole purpose is to keep the unsolved
block in the AST. This is primarily useful
when we need to solve the ODE system using
implicit methods, for instance, CVode.
- DerivativeOriginalJacobianBlock:
nmodl: "DERIVATIVE_ORIGINAL_JACOBIAN "
members:
- name:
brief: "Name of the derivative block"
type: Name
node_name: true
suffix: {value: " "}
- statement_block:
brief: "Block with statements vector"
type: StatementBlock
getter: {override: true}
brief: "Represents a copy of the `DERIVATIVE` block in NMODL with prime vars replaced by D vars and RHS by D vars / (1 - dt * J(vars))"
description: |
The original `DERIVATIVE` block in NMODL is
replaced in-place if the system of ODEs is
solvable analytically. Therefore, this
block's sole purpose is to keep the unsolved
block in the AST. This is primarily useful
when we need to solve the ODE system using
implicit methods, for instance, CVode.

- WrappedExpression:
brief: "Wrap any other expression type"
members:
Expand Down
11 changes: 10 additions & 1 deletion src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "visitors/after_cvode_to_cnexp_visitor.hpp"
#include "visitors/ast_visitor.hpp"
#include "visitors/constant_folder_visitor.hpp"
#include "visitors/derivative_original_visitor.hpp"
#include "visitors/function_callpath_visitor.hpp"
#include "visitors/global_var_visitor.hpp"
#include "visitors/implicit_argument_visitor.hpp"
Expand Down Expand Up @@ -498,10 +499,18 @@ int run_nmodl(int argc, const char* argv[]) {
const bool sympy_sparse = solver_exists(*ast, "sparse");

if (sympy_conductance || sympy_analytic || sympy_sparse || sympy_derivimplicit ||
sympy_linear) {
sympy_linear || neuron_code) {
nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance()
.api()
.initialize_interpreter();

if (neuron_code) {
logger->info("Running derivative visitor");
DerivativeOriginalVisitor().visit_program(*ast);
SymtabVisitor(update_symtab).visit_program(*ast);
ast_to_nmodl(*ast, filepath("derivative_original"));
}

if (sympy_conductance) {
logger->info("Running sympy conductance visitor");
SympyConductanceVisitor().visit_program(*ast);
Expand Down
40 changes: 38 additions & 2 deletions src/pybind/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

#include "codegen/codegen_naming.hpp"
#include "pybind/pyembed.hpp"

#include <fmt/format.h>
#include <pybind11/embed.h>
#include <pybind11/stl.h>

Expand Down Expand Up @@ -186,6 +186,41 @@ except Exception as e:
return {std::move(solution), std::move(exception_message)};
}

/// \brief A blunt instrument that differentiates expression w.r.t. variable
/// \return The tuple (solution, exception)
std::tuple<std::string, std::string> call_diff2c(
const std::string& expression,
const std::string& variable,
const std::unordered_map<std::string, int>& indexed_vars) {
std::string statements;
// only indexed variables require special treatment
for (const auto& [var, prop]: indexed_vars) {
statements += fmt::format("_allvars.append(sp.IndexedBase('{}', shape=[1]))\n", var);
}
auto locals = py::dict("expression"_a = expression, "variable"_a = variable);
std::string script = fmt::format(R"(
_allvars = []
{}
exception_message = ""
try:
solution = differentiate2c(expression,
variable,
_allvars,
)
except Exception as e:
# if we fail, fail silently and return empty string
solution = ""
exception_message = str(e)
)",
statements);

py::exec(nmodl::pybind_wrappers::ode_py + script, locals);

auto solution = locals["solution"].cast<std::string>();
auto exception_message = locals["exception_message"].cast<std::string>();

return {std::move(solution), std::move(exception_message)};
}

void initialize_interpreter_func() {
pybind11::initialize_interpreter(true);
Expand All @@ -203,7 +238,8 @@ NMODL_EXPORT pybind_wrap_api nmodl_init_pybind_wrapper_api() noexcept {
&call_solve_nonlinear_system,
&call_solve_linear_system,
&call_diffeq_solver,
&call_analytic_diff};
&call_analytic_diff,
&call_diff2c};
}
}

Expand Down
7 changes: 7 additions & 0 deletions src/pybind/wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <set>
#include <string>
#include <unordered_map>
#include <vector>

namespace nmodl {
Expand Down Expand Up @@ -44,13 +45,19 @@ std::tuple<std::string, std::string> call_analytic_diff(
const std::vector<std::string>& expressions,
const std::set<std::string>& used_names_in_block);

std::tuple<std::string, std::string> call_diff2c(
const std::string& expression,
const std::string& variable,
const std::unordered_map<std::string, int>& indexed_vars = {});

struct pybind_wrap_api {
decltype(&initialize_interpreter_func) initialize_interpreter;
decltype(&finalize_interpreter_func) finalize_interpreter;
decltype(&call_solve_nonlinear_system) solve_nonlinear_system;
decltype(&call_solve_linear_system) solve_linear_system;
decltype(&call_diffeq_solver) diffeq_solver;
decltype(&call_analytic_diff) analytic_diff;
decltype(&call_diff2c) diff2c;
};

#ifdef _WIN32
Expand Down
1 change: 1 addition & 0 deletions src/visitors/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_library(
visitor STATIC
after_cvode_to_cnexp_visitor.cpp
constant_folder_visitor.cpp
derivative_original_visitor.cpp
defuse_analyze_visitor.cpp
function_callpath_visitor.cpp
global_var_visitor.cpp
Expand Down
154 changes: 154 additions & 0 deletions src/visitors/derivative_original_visitor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/*
* Copyright 2023 Blue Brain Project, EPFL.
* See the top-level LICENSE file for details.
*
* SPDX-License-Identifier: Apache-2.0
*/

#include "visitors/derivative_original_visitor.hpp"

#include "ast/all.hpp"
#include "lexer/token_mapping.hpp"
#include "pybind/pyembed.hpp"
#include "utils/logger.hpp"
#include "visitors/visitor_utils.hpp"
#include <optional>
#include <utility>

namespace pywrap = nmodl::pybind_wrappers;

namespace nmodl {
namespace visitor {

static int get_index(const ast::IndexedName& node) {
return std::stoi(to_nmodl(node.get_length()));
}

static auto get_name_map(const ast::Expression& node, const std::string& name) {
std::unordered_map<std::string, int> name_map;
// all of the "reserved" symbols
auto reserved_symbols = get_external_functions();
// all indexed vars
auto indexed_vars = collect_nodes(node, {ast::AstNodeType::INDEXED_NAME});
for (const auto& var: indexed_vars) {
if (!name_map.count(var->get_node_name()) && var->get_node_name() != name &&
std::none_of(reserved_symbols.begin(), reserved_symbols.end(), [&var](const auto item) {
return var->get_node_name() == item;
})) {
logger->debug(
"DerivativeOriginalVisitor :: adding INDEXED_VARIABLE {} to "
"node_map",
var->get_node_name());
name_map[var->get_node_name()] = get_index(
*std::dynamic_pointer_cast<const ast::IndexedName>(var));
}
}
return name_map;
}


void DerivativeOriginalVisitor::visit_derivative_block(ast::DerivativeBlock& node) {
node.visit_children(*this);
der_block_function = std::shared_ptr<ast::DerivativeBlock>(node.clone());
der_block_jacobian = std::shared_ptr<ast::DerivativeBlock>(node.clone());
}


void DerivativeOriginalVisitor::visit_derivative_original_function_block(
ast::DerivativeOriginalFunctionBlock& node) {
derivative_block = true;
node_type = node.get_node_type();
node.visit_children(*this);
node_type = ast::AstNodeType::NODE;
derivative_block = false;
}

void DerivativeOriginalVisitor::visit_derivative_original_jacobian_block(
ast::DerivativeOriginalJacobianBlock& node) {
derivative_block = true;
node_type = node.get_node_type();
node.visit_children(*this);
node_type = ast::AstNodeType::NODE;
derivative_block = false;
}

void DerivativeOriginalVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) {
differential_equation = true;
node.visit_children(*this);
differential_equation = false;
}


void DerivativeOriginalVisitor::visit_binary_expression(ast::BinaryExpression& node) {
const auto& lhs = node.get_lhs();

/// we have to only solve ODEs under original derivative block where lhs is variable
if (!derivative_block || !differential_equation || !lhs->is_var_name()) {
return;
}

auto name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();

if (name->is_prime_name()) {
auto varname = "D" + name->get_node_name();
logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}",
name->get_node_name(),
varname,
to_nmodl(node));
node.set_lhs(std::make_shared<ast::Name>(new ast::String(varname)));
if (program_symtab->lookup(varname) == nullptr) {
auto symbol = std::make_shared<symtab::Symbol>(varname, ModToken());
symbol->set_original_name(name->get_node_name());
program_symtab->insert(symbol);
}
if (node_type == ast::AstNodeType::DERIVATIVE_ORIGINAL_JACOBIAN_BLOCK) {
logger->debug(
"DerivativeOriginalVisitor :: visiting expr {} in DERIVATIVE_ORIGINAL_JACOBIAN",
to_nmodl(node));
auto rhs = node.get_rhs();
// map of all indexed symbols (need special treatment in SymPy)
auto name_map = get_name_map(*rhs, name->get_node_name());
auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c;
auto [jacobian,
exception_message] = diff2c(to_nmodl(*rhs), name->get_node_name(), name_map);
if (!exception_message.empty()) {
logger->warn("DerivativeOriginalVisitor :: python exception: {}",
exception_message);
}
// NOTE: LHS can be anything here, the equality is to keep `create_statement` from
// complaining, we discard the LHS later
auto statement = fmt::format("{} = {} / (1 - dt * ({}))", varname, varname, jacobian);
logger->debug("DerivativeOriginalVisitor :: replacing statement {} with {}",
to_nmodl(node),
statement);
auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
create_statement(statement));
const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
expr_statement->get_expression());
node.set_rhs(std::shared_ptr<ast::Expression>(bin_expr->get_rhs()->clone()));
}
}
}

void DerivativeOriginalVisitor::visit_program(ast::Program& node) {
program_symtab = node.get_symbol_table();
node.visit_children(*this);
if (der_block_function) {
auto der_node =
new ast::DerivativeOriginalFunctionBlock(der_block_function->get_name(),
der_block_function->get_statement_block());
node.emplace_back_node(der_node);
}
if (der_block_jacobian) {
auto der_node =
new ast::DerivativeOriginalJacobianBlock(der_block_jacobian->get_name(),
der_block_jacobian->get_statement_block());
node.emplace_back_node(der_node);
}

// re-visit the AST since we now inserted the DERIVATIVE_ORIGINAL block
node.visit_children(*this);
}

} // namespace visitor
} // namespace nmodl
Loading
Loading