Skip to content

Commit

Permalink
Implement v via an instance specific copy.
Browse files Browse the repository at this point in the history
1uc committed Oct 28, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 750ec88 commit 804adcd
Showing 5 changed files with 131 additions and 6 deletions.
15 changes: 9 additions & 6 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
@@ -795,7 +795,10 @@ std::string CodegenNeuronCppVisitor::global_variable_name(const SymbolType& symb

std::string CodegenNeuronCppVisitor::get_variable_name(const std::string& name,
bool use_instance) const {
const std::string& varname = update_if_ion_variable_name(name);
std::string varname = update_if_ion_variable_name(name);
if (!info.artificial_cell && varname == "v") {
varname = naming::VOLTAGE_UNUSED_VARIABLE;
}

auto name_comparator = [&varname](const auto& sym) { return varname == get_name(sym); };

@@ -956,9 +959,6 @@ void CodegenNeuronCppVisitor::print_sdlists_init(bool /* print_initializers */)

CodegenCppVisitor::ParamVector CodegenNeuronCppVisitor::functor_params() {
auto params = internal_method_parameters();
if (!info.artificial_cell) {
params.push_back({"", "double", "", "v"});
}

return params;
}
@@ -1815,7 +1815,7 @@ void CodegenNeuronCppVisitor::print_nrn_init(bool skip_init_check) {
printer->add_line("auto* _ppvar = _ml_arg->pdata[id];");
if (!info.artificial_cell) {
printer->add_line("int node_id = node_data.nodeindices[id];");
printer->add_line("auto v = node_data.node_voltages[node_id];");
printer->add_line("inst.v_unused[id] = node_data.node_voltages[node_id];");
}

print_rename_state_vars();
@@ -2062,7 +2062,9 @@ void CodegenNeuronCppVisitor::print_nrn_state() {
printer->push_block("for (int id = 0; id < nodecount; id++)");
printer->add_line("int node_id = node_data.nodeindices[id];");
printer->add_line("auto* _ppvar = _ml_arg->pdata[id];");
printer->add_line("auto v = node_data.node_voltages[node_id];");
if (!info.artificial_cell) {
printer->add_line("inst.v_unused[id] = node_data.node_voltages[node_id];");
}

/**
* \todo Eigen solver node also emits IonCurVar variable in the functor
@@ -2135,6 +2137,7 @@ void CodegenNeuronCppVisitor::print_nrn_current(const BreakpointBlock& node) {
printer->fmt_push_block("static inline double nrn_current_{}({})",
info.mod_suffix,
get_parameter_str(args));
printer->add_line("inst.v_unused[id] = v;");
printer->add_line("double current = 0.0;");
print_statement_block(*block, false, false);
for (auto& current: info.currents) {
18 changes: 18 additions & 0 deletions test/usecases/voltage/accessors.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
NEURON {
SUFFIX accessors
NONSPECIFIC_CURRENT il
}

ASSIGNED {
v
il
}

BREAKPOINT {
il = 0.003
}


FUNCTION get_voltage() {
get_voltage = v
}
17 changes: 17 additions & 0 deletions test/usecases/voltage/ode.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
NEURON {
SUFFIX ode
NONSPECIFIC_CURRENT il
}

ASSIGNED {
il
v
}

FUNCTION voltage() {
voltage = 0.001 * v
}

BREAKPOINT {
il = voltage()
}
31 changes: 31 additions & 0 deletions test/usecases/voltage/state_ode.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
NEURON {
SUFFIX state_ode
NONSPECIFIC_CURRENT il
}

STATE {
X
}

ASSIGNED {
il
v
}

INITIAL {
X = v
}

BREAKPOINT {
SOLVE eqn
il = 0.001 * X
}

NONLINEAR eqn { LOCAL c
c = rate()
~ X = c
}

FUNCTION rate() {
rate = v
}
56 changes: 56 additions & 0 deletions test/usecases/voltage/test_voltage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from neuron import h, gui

import numpy as np


def test_voltage_access():
s = h.Section()
s.insert("accessors")

h.finitialize()
v = s(0.5).v
vinst = s(0.5).accessors.get_voltage()
# The voltage will be consistent right after
# finitialize.
assert vinst == v

for _ in range(4):
v = s(0.5).v
h.fadvance()
vinst = s(0.5).accessors.get_voltage()

# During timestepping the internal copy
# of the voltage lags behind the current
# voltage by some timestep.
assert vinst == v, f"{vinst = }, {v = }, delta = {vinst - v}"


def check_ode(mech_name, step):
s = h.Section()
s.insert(mech_name)

h.finitialize()

c = -0.001 / 1e-3

for _ in range(4):
v_expected = step(s(0.5).v, c)
h.fadvance()
# print(f"{s(0.5).v = }, {v_expected}")
np.testing.assert_approx_equal(s(0.5).v, v_expected, significant=10)


def test_breakpoint():
# Results in backward Euler.
check_ode("ode", lambda v, c: (1.0 - c * h.dt) ** (-1.0) * v)


def test_state():
# Effectively, the timing when states are computed results in backward Euler.
check_ode("state_ode", lambda v, c: (1.0 + c * h.dt) * v)


if __name__ == "__main__":
# test_voltage_access()
# test_breakpoint()
test_state()

0 comments on commit 804adcd

Please sign in to comment.