Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
# nest_integration_test.py
# test_nest_neuron_model_equivalence.py
#
# This file is part of NEST.
#
Expand Down Expand Up @@ -45,7 +45,7 @@ def get_model_doc_title(model_fname: str):

@pytest.mark.skipif(NESTTools.detect_nest_version().startswith("v2"),
reason="This test does not support NEST 2")
class TestNestIntegration:
class TestNESTNeuronModelEquivalence:

def generate_all_models(self):
codegen_opts = {}
Expand Down Expand Up @@ -89,7 +89,7 @@ def generate_all_models(self):
suffix="_alt_int_nestml",
codegen_opts=alt_codegen_opts)

def test_nest_integration(self):
def test_nest_neuron_model_equivalence(self):
self.generate_all_models()
nest.Install("nestml_allmodels_module")
nest.Install("nestml_alt_allmodels_module")
Expand Down Expand Up @@ -179,19 +179,19 @@ def _test_model_equivalence_curr_inj(self, nest_model_name, nestml_model_name, g
# ResetKernel() does not unload modules for NEST Simulator < v3.7; ignore exception if module is already loaded on earlier versions
pass

neuron1 = nest.Create(nest_model_name, params=nest_model_parameters)
neuron2 = nest.Create(nestml_model_name, params=nestml_model_parameters)
nest_neuron = nest.Create(nest_model_name, params=nest_model_parameters)
nestml_neuron = nest.Create(nestml_model_name, params=nestml_model_parameters)
if model_initial_state is not None:
nest.SetStatus(neuron1, model_initial_state)
nest.SetStatus(neuron2, model_initial_state)
nest.SetStatus(nest_neuron, model_initial_state)
nest.SetStatus(nestml_neuron, model_initial_state)

# if gsl_error_tol is not None:
# nest.SetStatus(neuron2, {"gsl_error_tol": gsl_error_tol})
# nest.SetStatus(nestml_neuron, {"gsl_error_tol": gsl_error_tol})

dc = nest.Create("dc_generator", params={"amplitude": 0.})

nest.Connect(dc, neuron1)
nest.Connect(dc, neuron2)
nest.Connect(dc, nest_neuron)
nest.Connect(dc, nestml_neuron)

multimeter1 = nest.Create("multimeter")
multimeter2 = nest.Create("multimeter")
Expand All @@ -200,8 +200,8 @@ def _test_model_equivalence_curr_inj(self, nest_model_name, nestml_model_name, g
nest.SetStatus(multimeter1, {"record_from": [V_m_specifier]})
nest.SetStatus(multimeter2, {"record_from": [V_m_specifier]})

nest.Connect(multimeter1, neuron1)
nest.Connect(multimeter2, neuron2)
nest.Connect(multimeter1, nest_neuron)
nest.Connect(multimeter2, nestml_neuron)

if NESTTools.detect_nest_version().startswith("v2"):
sd_reference = nest.Create("spike_detector")
Expand All @@ -210,8 +210,8 @@ def _test_model_equivalence_curr_inj(self, nest_model_name, nestml_model_name, g
sd_reference = nest.Create("spike_recorder")
sd_testant = nest.Create("spike_recorder")

nest.Connect(neuron1, sd_reference)
nest.Connect(neuron2, sd_testant)
nest.Connect(nest_neuron, sd_reference)
nest.Connect(nestml_neuron, sd_testant)

nest.Simulate(t_pulse_start)
dc.amplitude = I_stim * 1E12 # 1E12: convert A to pA
Expand All @@ -236,7 +236,7 @@ def _test_model_equivalence_curr_inj(self, nest_model_name, nestml_model_name, g
for _ax in ax:
_ax.legend(loc="upper right")
_ax.grid()
plt.savefig("/tmp/nestml_nest_integration_test_pulse_[" + nest_model_name + "]_[" + nestml_model_name + "]_[I_stim=" + str(I_stim) + "].png")
plt.savefig("/tmp/test_nest_neuron_model_equivalence_pulse_[" + nest_model_name + "]_[" + nestml_model_name + "]_[I_stim=" + str(I_stim) + "].png")
plt.close(fig)

np.testing.assert_allclose(Vms1, Vms2)
Expand All @@ -262,19 +262,19 @@ def _test_model_equivalence_fI_curve(self, nest_model_name, nestml_model_name, g
# ResetKernel() does not unload modules for NEST Simulator < v3.7; ignore exception if module is already loaded on earlier versions
pass

neuron1 = nest.Create(nest_model_name, params=nest_model_parameters)
neuron2 = nest.Create(nestml_model_name, params=nestml_model_parameters)
nest_neuron = nest.Create(nest_model_name, params=nest_model_parameters)
nestml_neuron = nest.Create(nestml_model_name, params=nestml_model_parameters)
if model_initial_state is not None:
nest.SetStatus(neuron1, model_initial_state)
nest.SetStatus(neuron2, model_initial_state)
nest.SetStatus(nest_neuron, model_initial_state)
nest.SetStatus(nestml_neuron, model_initial_state)

# if gsl_error_tol is not None:
# nest.SetStatus(neuron2, {"gsl_error_tol": gsl_error_tol})
# nest.SetStatus(nestml_neuron, {"gsl_error_tol": gsl_error_tol})

dc = nest.Create("dc_generator", params={"amplitude": 1E12 * I_stim}) # 1E12: convert A to pA

nest.Connect(dc, neuron1)
nest.Connect(dc, neuron2)
nest.Connect(dc, nest_neuron)
nest.Connect(dc, nestml_neuron)

multimeter1 = nest.Create("multimeter")
multimeter2 = nest.Create("multimeter")
Expand All @@ -283,8 +283,8 @@ def _test_model_equivalence_fI_curve(self, nest_model_name, nestml_model_name, g
nest.SetStatus(multimeter1, {"record_from": [V_m_specifier]})
nest.SetStatus(multimeter2, {"record_from": [V_m_specifier]})

nest.Connect(multimeter1, neuron1)
nest.Connect(multimeter2, neuron2)
nest.Connect(multimeter1, nest_neuron)
nest.Connect(multimeter2, nestml_neuron)

if NESTTools.detect_nest_version().startswith("v2"):
sd_reference = nest.Create("spike_detector")
Expand All @@ -293,8 +293,8 @@ def _test_model_equivalence_fI_curve(self, nest_model_name, nestml_model_name, g
sd_reference = nest.Create("spike_recorder")
sd_testant = nest.Create("spike_recorder")

nest.Connect(neuron1, sd_reference)
nest.Connect(neuron2, sd_testant)
nest.Connect(nest_neuron, sd_reference)
nest.Connect(nestml_neuron, sd_testant)

nest.Simulate(t_stop)

Expand All @@ -317,7 +317,7 @@ def _test_model_equivalence_fI_curve(self, nest_model_name, nestml_model_name, g
_ax.legend(loc="upper right")
_ax.grid()
fig.suptitle("Rate: " + str(rate_testant[i]) + " Hz")
plt.savefig("/tmp/nestml_nest_integration_test_subthreshold_[" + nest_model_name + "]_[" + nestml_model_name + "]_[I_stim=" + str(I_stim) + "].png")
plt.savefig("/tmp/test_nest_neuron_model_fI_curve_[" + nest_model_name + "]_[" + nestml_model_name + "]_[I_stim=" + str(I_stim) + "].png")
plt.close(fig)

if TEST_PLOTS:
Expand All @@ -333,7 +333,7 @@ def _test_model_equivalence_fI_curve(self, nest_model_name, nestml_model_name, g
_ax.grid()
_ax.set_ylabel("Firing rate [Hz]")
ax[1].set_xlabel("$I_{inj}$ [pA]")
plt.savefig("/tmp/nestml_nest_integration_test_subthreshold_[" + nest_model_name + "]_[" + nestml_model_name + "].png")
plt.savefig("/tmp/test_nest_neuron_model_equivalence_subthreshold_[" + nest_model_name + "]_[" + nestml_model_name + "].png")
plt.close(fig)

for figsize, fname_snip in zip([(8, 5), (4, 3)], ["", "_small"]):
Expand Down Expand Up @@ -365,26 +365,35 @@ def _test_model_equivalence_psc(self, nest_model_name, nestml_model_name, gsl_er
# ResetKernel() does not unload modules for NEST Simulator < v3.7; ignore exception if module is already loaded on earlier versions
pass

neuron1 = nest.Create(nest_model_name, params=nest_model_parameters)
neuron2 = nest.Create(nestml_model_name, params=nestml_model_parameters)
nest_neuron = nest.Create(nest_model_name, params=nest_model_parameters)
nestml_neuron = nest.Create(nestml_model_name, params=nestml_model_parameters)

if model_initial_state is not None:
nest.SetStatus(neuron1, model_initial_state)
nest.SetStatus(neuron2, model_initial_state)
nest.SetStatus(nest_neuron, model_initial_state)
nest.SetStatus(nestml_neuron, model_initial_state)

# if gsl_error_tol is not None:
# nest.SetStatus(neuron2, {"gsl_error_tol": gsl_error_tol})
# nest.SetStatus(nestml_neuron, {"gsl_error_tol": gsl_error_tol})

spikegenerator = nest.Create("spike_generator",
params={"spike_times": spike_times, "spike_weights": spike_weights})

nest.Connect(spikegenerator, neuron1, syn_spec=syn_spec)
nest.Connect(spikegenerator, neuron2, syn_spec=syn_spec)
nest.Connect(spikegenerator, nest_neuron, syn_spec=syn_spec)

if "receptor_types" in nestml_neuron.get().keys() and len(nestml_neuron.get("receptor_types")) > 1:
# this NESTML neuron is written as having separate input ports for excitatory and inhibitory spikes
syn_spec_nestml = syn_spec
if syn_spec_nestml is None:
syn_spec_nestml = {}
syn_spec_nestml.update({"receptor_type": nestml_neuron.get("receptor_types")["EXC_SPIKES"]})
nest.Connect(spikegenerator, nestml_neuron, syn_spec=syn_spec_nestml)
else:
# this NESTML neuron is written as having one input port for excitatory and inhibitory spikes (with sign of the weight telling the difference)
nest.Connect(spikegenerator, nestml_neuron, syn_spec=syn_spec)

spike_recorder1 = nest.Create("spike_recorder")
spike_recorder2 = nest.Create("spike_recorder")
nest.Connect(neuron1, spike_recorder1)
nest.Connect(neuron2, spike_recorder2)
nest.Connect(nest_neuron, spike_recorder1)
nest.Connect(nestml_neuron, spike_recorder2)

multimeter1 = nest.Create("multimeter")
multimeter2 = nest.Create("multimeter")
Expand All @@ -393,8 +402,8 @@ def _test_model_equivalence_psc(self, nest_model_name, nestml_model_name, gsl_er
nest.SetStatus(multimeter1, {"record_from": [V_m_specifier]})
nest.SetStatus(multimeter2, {"record_from": [V_m_specifier]})

nest.Connect(multimeter1, neuron1)
nest.Connect(multimeter2, neuron2)
nest.Connect(multimeter1, nest_neuron)
nest.Connect(multimeter2, nestml_neuron)

nest.Simulate(400.)

Expand All @@ -416,7 +425,7 @@ def _test_model_equivalence_psc(self, nest_model_name, nestml_model_name, gsl_er
for _ax in ax:
_ax.legend(loc="upper right")
_ax.grid()
plt.savefig("/tmp/nestml_nest_integration_test_psc_[" + nest_model_name + "]_[" + nestml_model_name + "].png")
plt.savefig("/tmp/test_nest_neuron_model_equivalence_psc_[" + nest_model_name + "]_[" + nestml_model_name + "].png")
plt.close(fig)

np.testing.assert_allclose(ts1, ts2)
Expand Down
Loading