-
Notifications
You must be signed in to change notification settings - Fork 6
/
20240115-reparametrization.patch
97 lines (86 loc) · 3.68 KB
/
20240115-reparametrization.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
diff --git a/spiking_simulations/LPLConnection.cpp b/spiking_simulations/LPLConnection.cpp
index dfda86e..5e3e47e 100755
--- a/spiking_simulations/LPLConnection.cpp
+++ b/spiking_simulations/LPLConnection.cpp
@@ -38,7 +38,7 @@ void LPLConnection::init(AurynFloat eta, AurynFloat tau_mean, AurynFloat tau_sig
eta_ = eta;
auryn::logger->parameter("eta",eta);
- delta = 1e-5; // strength of transmitter triggered plasticity
+ delta = 1e-3; // strength of transmitter triggered plasticity
auryn::logger->parameter("delta",delta);
epsilon = 1e-32; //!< for division by zero cases
@@ -187,7 +187,7 @@ LPLConnection::LPLConnection(
name)
{
init(eta, tau_mean, tau_sigma2, maxweight);
- lambda_ = lambda;
+ lambda_ = lambda*auryn_timestep;
phi_ = phi;
if ( name.empty() )
set_name("LPLConnection");
@@ -247,10 +247,10 @@ void LPLConnection::compute_err()
err->add_specific(trspk, 1.0);
}
temp->copy(tr_post_sigma2);
- temp->add(1e-3); // TODO make a parameter called xi to be used denominator
+ temp->add(1e-7); // TODO make a parameter called xi to be used denominator
err->div(temp);
- err->scale(lambda_);
- err->add(delta); // add transmitter triggered plasticity
+ err->scale(lambda_); // TESTING new changes
+ // err->add(delta); // add transmitter triggered plasticity
// Compute -dzdt error signal on spike trains using the van Rossum trick
for ( unsigned int i = 0 ; i < sc->size(); ++i ) add_to_err(sc->at(i), -1.0*phi_);
@@ -258,9 +258,11 @@ void LPLConnection::compute_err()
// some qnd monitoring
if ( false && sys->get_clock()%100000==0) std::cout << std::scientific
+ << " time " << sys->get_time()
<< " avgsqrerr " << avgsqrerr->mean()
<< ", mean firing rate " << tr_post_mean->mean()/tr_post_mean->get_tau()
<< ", tr_post_sigma2 " << tr_post_sigma2->mean()
+ << ", err var " << err->var()
<< std::endl;
@@ -331,6 +333,23 @@ void LPLConnection::process_plasticity()
el_sum->add_specific(didx, de);
}
}
+
+
+ // add transmitter triggered plasticity term
+ // loop over all pre spikes
+ for (SpikeContainer::const_iterator spike = src->get_spikes()->begin() ; // spike = pre_spike
+ spike != src->get_spikes()->end() ; ++spike ) {
+ // loop over all postsynaptic target cells
+ for (const NeuronID * c = w->get_row_begin(*spike) ;
+ c != w->get_row_end(*spike) ;
+ ++c ) { // c = post index
+
+ // determines the weight of connection
+ AurynWeight * weight = w->get_data_ptr(c);
+ const AurynLong didx = w->data_ptr_to_didx(weight);
+ el_sum->add_specific(didx, delta); // add the transmitter triggered term
+ }
+ }
}
@@ -482,7 +501,7 @@ void LPLConnection::evolve()
temp->sqr();
// add to moving tr_post_sigma2
- temp->scale(1.0/tr_post_sigma2->get_tau()); // correct for mean time scale
+ temp->scale(1.0*auryn_timestep/tr_post_sigma2->get_tau()); // correct for mean time scale
tr_post_sigma2->add(temp);
tr_post_sigma2->evolve();
}
diff --git a/spiking_simulations/sim_lpl_spiking.cpp b/spiking_simulations/sim_lpl_spiking.cpp
index faf2c9c..e3ec8ed 100755
--- a/spiking_simulations/sim_lpl_spiking.cpp
+++ b/spiking_simulations/sim_lpl_spiking.cpp
@@ -421,6 +421,11 @@ int main(int ac, char* av[])
con_lpl->set_tau_rms(tau_rms);
con_lpl->set_max_weight(wmax);
+ // record LPL dynamic variables
+ const NeuronID rec_unit = 7;
+ StateMonitor * statmon_lpl1 = new StateMonitor( con_lpl->tr_post_mean, rec_unit, sys->fn("tr_post_mean","dat"), 1.0 );
+ StateMonitor * statmon_lpl2 = new StateMonitor( con_lpl->tr_post_sigma2, rec_unit, sys->fn("tr_post_sigma2","dat"), 1.0 );
+
if ( w_ee > 0.0 ) {
LPLConnection * con_lpl_rec = new LPLConnection(neurons_e, neurons_e, w_ext, sparseness, eta, lambda, phi, tau_mean, tau_sigma2);