Skip to content

Commit

Permalink
Merge pull request #119 from PKU-NIP-Lab/whole-brain-modeling
Browse files Browse the repository at this point in the history
fix bugs
  • Loading branch information
chaoming0625 authored Mar 21, 2022
2 parents d2f9254 + ed4e5e5 commit 3086c69
Show file tree
Hide file tree
Showing 52 changed files with 838 additions and 412 deletions.
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,18 @@ runner.run(100.)



Numerical methods for delay differential equations (SDEs).
Numerical methods for delay differential equations (SDEs).

```python
xdelay = bm.FixedLenDelay(1, delay_len=1., before_t0=1., dt=0.01)
xdelay = bm.TimeDelay(1, delay_len=1., before_t0=1., dt=0.01)


@bp.ddeint(method='rk4', state_delays={'x': xdelay})
def second_order_eq(x, y, t):
dx = y
dy = -y - 2*x - 0.5*xdelay(t-1)
return dx, dy
dx = y
dy = -y - 2 * x - 0.5 * xdelay(t - 1)
return dx, dy


runner = bp.integrators.IntegratorRunner(second_order_eq, dt=0.01)
runner.run(100.)
Expand Down
6 changes: 4 additions & 2 deletions brainpy/analysis/utils/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import jax.numpy as jnp
import numpy as np
from brainpy.tools.others import numba_jit


__all__ = [
Expand All @@ -10,7 +11,7 @@
]


# @tools.numba_jit
@numba_jit
def _f1(arr, grad, tol):
condition = np.logical_and(grad[:-1] * grad[1:] <= 0, grad[:-1] >= 0)
indexes = np.where(condition)[0]
Expand All @@ -19,7 +20,8 @@ def _f1(arr, grad, tol):
length = np.max(data) - np.min(data)
a = arr[indexes[-2]]
b = arr[indexes[-1]]
if np.abs(a - b) <= tol * length:
# TODO: how to choose length threshold, 1e-3?
if length > 1e-3 and np.abs(a - b) <= tol * length:
return indexes[-2:]
return np.array([-1, -1])

Expand Down
3 changes: 1 addition & 2 deletions brainpy/analysis/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def model_transform(model):
new_model = []
for intg in model:
if isinstance(intg.f, JointEq):
new_model.extend([type(intg)(eq, var_type=intg.var_type, dt=intg.dt, dyn_var=intg.dyn_var)
for eq in intg.f.eqs])
new_model.extend([type(intg)(eq, var_type=intg.var_type, dt=intg.dt) for eq in intg.f.eqs])
else:
new_model.append(intg)

Expand Down
27 changes: 27 additions & 0 deletions brainpy/check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-


__all__ = [
'is_checking',
'turn_on',
'turn_off',
]

_check = True


def is_checking():
"""Whether the checking is turn on."""
return _check


def turn_on():
"""Turn on the checking."""
global _check
_check = True


def turn_off():
"""Turn off the checking."""
global _check
_check = False
2 changes: 1 addition & 1 deletion brainpy/datasets/chaotic_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def mackey_glass_series(duration, dt=0.1, beta=2., gamma=1., tau=2., n=9.65,
assert isinstance(inits, (bm.ndarray, jnp.ndarray))

rng = bm.random.RandomState(seed)
xdelay = bm.FixedLenDelay(inits.shape, tau, dt=dt)
xdelay = bm.TimeDelay(inits.shape, tau, dt=dt)
xdelay.data = inits + 0.2 * (rng.random((xdelay.num_delay_step,) + inits.shape) - 0.5)

@ddeint(method=method, state_delays={'x': xdelay})
Expand Down
138 changes: 6 additions & 132 deletions brainpy/dyn/neurons/rate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
__all__ = [
'FHN',
'FeedbackFHN',
'MeanFieldQIF',
]


Expand Down Expand Up @@ -197,7 +196,6 @@ def __init__(self,
tau: Parameter = 12.5,
mu: Parameter = 1.6886,
v0: Parameter = -1,
Vth: Parameter = 1.8,
method: str = 'rk4',
name: str = None):
super(FeedbackFHN, self).__init__(size=size, name=name)
Expand All @@ -209,23 +207,21 @@ def __init__(self,
self.tau = tau
self.mu = mu # feedback strength
self.v0 = v0 # resting potential
self.Vth = Vth

# variables
self.w = bm.Variable(bm.zeros(self.num))
self.V = bm.Variable(bm.zeros(self.num))
self.Vdelay = bm.FixedLenDelay(self.num, self.delay)
self.Vdelay = bm.TimeDelay(self.num, self.delay, interp_method='round')
self.input = bm.Variable(bm.zeros(self.num))
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

# integral
self.integral = ddeint(method=method, f=self.derivative,
self.integral = ddeint(method=method,
f=self.derivative,
state_delays={'V': self.Vdelay})

def dV(self, V, t, w, Vdelay):
def dV(self, V, t, w):
return (V - V * V * V / 3 - w + self.input +
self.mu * (Vdelay(t - self.delay) - self.v0))
self.mu * (self.Vdelay(t - self.delay) - self.v0))

def dw(self, w, t, V):
return (V + self.a - self.b * w) / self.tau
Expand All @@ -235,129 +231,7 @@ def derivative(self):
return JointEq([self.dV, self.dw])

def update(self, _t, _dt):
V, w = self.integral(self.V, self.w, _t, Vdelay=self.Vdelay, dt=_dt)
self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth)
self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike)
V, w = self.integral(self.V, self.w, _t, dt=_dt)
self.V.value = V
self.w.value = w
self.input[:] = 0.


class MeanFieldQIF(NeuGroup):
r"""A mean-field model of a quadratic integrate-and-fire neuron population.
**Model Descriptions**
The QIF population mean-field model, which has been derived from a
population of all-to-all coupled QIF neurons in [5]_.
The model equations are given by:
.. math::
\begin{aligned}
\tau \dot{r} &=\frac{\Delta}{\pi \tau}+2 r v \\
\tau \dot{v} &=v^{2}+\bar{\eta}+I(t)+J r \tau-(\pi r \tau)^{2}
\end{aligned}
where :math:`r` is the average firing rate and :math:`v` is the
average membrane potential of the QIF population [5]_.
This mean-field model is an exact representation of the macroscopic
firing rate and membrane potential dynamics of a spiking neural network
consisting of QIF neurons with Lorentzian distributed background
excitabilities. While the mean-field derivation is mathematically
only valid for all-to-all coupled populations of infinite size, it
has been shown that there is a close correspondence between the
mean-field model and neural populations with sparse coupling and
population sizes of a few thousand neurons [6]_.
**Model Parameters**
============= ============== ======== ========================
**Parameter** **Init Value** **Unit** **Explanation**
------------- -------------- -------- ------------------------
tau 1 ms the population time constant
eta -5. \ the mean of a Lorenzian distribution over the neural excitability in the population
delta 1.0 \ the half-width at half maximum of the Lorenzian distribution over the neural excitability
J 15 \ the strength of the recurrent coupling inside the population
============= ============== ======== ========================
References
----------
.. [5] E. Montbrió, D. Pazó, A. Roxin (2015) Macroscopic description for
networks of spiking neurons. Physical Review X, 5:021028,
https://doi.org/10.1103/PhysRevX.5.021028.
.. [6] R. Gast, H. Schmidt, T.R. Knösche (2020) A Mean-Field Description
of Bursting Dynamics in Spiking Neural Networks with Short-Term
Adaptation. Neural Computation 32.9 (2020): 1615-1634.
"""

def __init__(self,
size: Shape,
tau: Parameter = 1.,
eta: Parameter = -5.0,
delta: Parameter = 1.0,
J: Parameter = 15.,
method: str = 'exp_auto',
name: str = None):
super(MeanFieldQIF, self).__init__(size=size, name=name)

# parameters
self.tau = tau #
self.eta = eta # the mean of a Lorenzian distribution over the neural excitability in the population
self.delta = delta # the half-width at half maximum of the Lorenzian distribution over the neural excitability
self.J = J # the strength of the recurrent coupling inside the population

# variables
self.r = bm.Variable(bm.ones(1))
self.V = bm.Variable(bm.ones(1))
self.input = bm.Variable(bm.zeros(1))

# functions
self.integral = odeint(self.derivative, method=method)

def dr(self, r, t, v):
return (self.delta / (bm.pi * self.tau) + 2. * r * v) / self.tau

def dV(self, v, t, r):
return (v ** 2 + self.eta + self.input + self.J * r * self.tau -
(bm.pi * r * self.tau) ** 2) / self.tau

@property
def derivative(self):
return JointEq([self.dV, self.dr])

def update(self, _t, _dt):
self.V.value, self.r.value = self.integral(self.V, self.r, _t, _dt)
self.integral[:] = 0.



class VanDerPolOscillator(NeuGroup):
pass


class ThetaNeuron(NeuGroup):
pass


class MeanFieldQIFWithSFA(NeuGroup):
pass


class JansenRitModel(NeuGroup):
pass


class WilsonCowanModel(NeuGroup):
pass

class StuartLandauOscillator(NeuGroup):
pass


class KuramotoOscillator(NeuGroup):
pass

4 changes: 4 additions & 0 deletions brainpy/dyn/rates/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-

from .base import *
from .fhn import *
34 changes: 34 additions & 0 deletions brainpy/dyn/rates/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-

from brainpy.dyn.base import DynamicalSystem
from brainpy.tools.others import to_size, size2num
from brainpy.types import Shape

__all__ = [
'RateModel',
]


class RateModel(DynamicalSystem):
"""Base class of rate models."""

def __init__(self,
size: Shape,
name: str = None):
super(RateModel, self).__init__(name=name)

self.size = to_size(size)
self.num = size2num(self.size)

def update(self, _t, _dt):
"""The function to specify the updating rule.
Parameters
----------
_t : float
The current time.
_dt : float
The time step.
"""
raise NotImplementedError(f'Subclass of {self.__class__.__name__} must '
f'implement "update" function.')
Loading

0 comments on commit 3086c69

Please sign in to comment.