diff --git a/README.md b/README.md index 2db5d0b1e..15a6b8dbb 100644 --- a/README.md +++ b/README.md @@ -147,16 +147,18 @@ runner.run(100.) -Numerical methods for delay differential equations (SDEs). +Numerical methods for delay differential equations (SDEs). ```python -xdelay = bm.FixedLenDelay(1, delay_len=1., before_t0=1., dt=0.01) +xdelay = bm.TimeDelay(1, delay_len=1., before_t0=1., dt=0.01) + @bp.ddeint(method='rk4', state_delays={'x': xdelay}) def second_order_eq(x, y, t): - dx = y - dy = -y - 2*x - 0.5*xdelay(t-1) - return dx, dy + dx = y + dy = -y - 2 * x - 0.5 * xdelay(t - 1) + return dx, dy + runner = bp.integrators.IntegratorRunner(second_order_eq, dt=0.01) runner.run(100.) diff --git a/brainpy/analysis/utils/measurement.py b/brainpy/analysis/utils/measurement.py index 3cf4e76b3..24d7d9dd0 100644 --- a/brainpy/analysis/utils/measurement.py +++ b/brainpy/analysis/utils/measurement.py @@ -2,6 +2,7 @@ import jax.numpy as jnp import numpy as np +from brainpy.tools.others import numba_jit __all__ = [ @@ -10,7 +11,7 @@ ] -# @tools.numba_jit +@numba_jit def _f1(arr, grad, tol): condition = np.logical_and(grad[:-1] * grad[1:] <= 0, grad[:-1] >= 0) indexes = np.where(condition)[0] @@ -19,7 +20,8 @@ def _f1(arr, grad, tol): length = np.max(data) - np.min(data) a = arr[indexes[-2]] b = arr[indexes[-1]] - if np.abs(a - b) <= tol * length: + # TODO: how to choose length threshold, 1e-3? + if length > 1e-3 and np.abs(a - b) <= tol * length: return indexes[-2:] return np.array([-1, -1]) diff --git a/brainpy/analysis/utils/model.py b/brainpy/analysis/utils/model.py index d499394d0..2a3ab2b1d 100644 --- a/brainpy/analysis/utils/model.py +++ b/brainpy/analysis/utils/model.py @@ -49,8 +49,7 @@ def model_transform(model): new_model = [] for intg in model: if isinstance(intg.f, JointEq): - new_model.extend([type(intg)(eq, var_type=intg.var_type, dt=intg.dt, dyn_var=intg.dyn_var) - for eq in intg.f.eqs]) + new_model.extend([type(intg)(eq, var_type=intg.var_type, dt=intg.dt) for eq in intg.f.eqs]) else: new_model.append(intg) diff --git a/brainpy/check.py b/brainpy/check.py new file mode 100644 index 000000000..55fc5a9d8 --- /dev/null +++ b/brainpy/check.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- + + +__all__ = [ + 'is_checking', + 'turn_on', + 'turn_off', +] + +_check = True + + +def is_checking(): + """Whether the checking is turn on.""" + return _check + + +def turn_on(): + """Turn on the checking.""" + global _check + _check = True + + +def turn_off(): + """Turn off the checking.""" + global _check + _check = False diff --git a/brainpy/datasets/chaotic_systems.py b/brainpy/datasets/chaotic_systems.py index 75d1c03b4..9da48420a 100644 --- a/brainpy/datasets/chaotic_systems.py +++ b/brainpy/datasets/chaotic_systems.py @@ -167,7 +167,7 @@ def mackey_glass_series(duration, dt=0.1, beta=2., gamma=1., tau=2., n=9.65, assert isinstance(inits, (bm.ndarray, jnp.ndarray)) rng = bm.random.RandomState(seed) - xdelay = bm.FixedLenDelay(inits.shape, tau, dt=dt) + xdelay = bm.TimeDelay(inits.shape, tau, dt=dt) xdelay.data = inits + 0.2 * (rng.random((xdelay.num_delay_step,) + inits.shape) - 0.5) @ddeint(method=method, state_delays={'x': xdelay}) diff --git a/brainpy/dyn/neurons/rate_models.py b/brainpy/dyn/neurons/rate_models.py index 12385f7d9..8a11af87c 100644 --- a/brainpy/dyn/neurons/rate_models.py +++ b/brainpy/dyn/neurons/rate_models.py @@ -10,7 +10,6 @@ __all__ = [ 'FHN', 'FeedbackFHN', - 'MeanFieldQIF', ] @@ -197,7 +196,6 @@ def __init__(self, tau: Parameter = 12.5, mu: Parameter = 1.6886, v0: 
Parameter = -1, - Vth: Parameter = 1.8, method: str = 'rk4', name: str = None): super(FeedbackFHN, self).__init__(size=size, name=name) @@ -209,23 +207,21 @@ def __init__(self, self.tau = tau self.mu = mu # feedback strength self.v0 = v0 # resting potential - self.Vth = Vth # variables self.w = bm.Variable(bm.zeros(self.num)) self.V = bm.Variable(bm.zeros(self.num)) - self.Vdelay = bm.FixedLenDelay(self.num, self.delay) + self.Vdelay = bm.TimeDelay(self.num, self.delay, interp_method='round') self.input = bm.Variable(bm.zeros(self.num)) - self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) - self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) # integral - self.integral = ddeint(method=method, f=self.derivative, + self.integral = ddeint(method=method, + f=self.derivative, state_delays={'V': self.Vdelay}) - def dV(self, V, t, w, Vdelay): + def dV(self, V, t, w): return (V - V * V * V / 3 - w + self.input + - self.mu * (Vdelay(t - self.delay) - self.v0)) + self.mu * (self.Vdelay(t - self.delay) - self.v0)) def dw(self, w, t, V): return (V + self.a - self.b * w) / self.tau @@ -235,129 +231,7 @@ def derivative(self): return JointEq([self.dV, self.dw]) def update(self, _t, _dt): - V, w = self.integral(self.V, self.w, _t, Vdelay=self.Vdelay, dt=_dt) - self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth) - self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike) + V, w = self.integral(self.V, self.w, _t, dt=_dt) self.V.value = V self.w.value = w self.input[:] = 0. - - -class MeanFieldQIF(NeuGroup): - r"""A mean-field model of a quadratic integrate-and-fire neuron population. - - **Model Descriptions** - - The QIF population mean-field model, which has been derived from a - population of all-to-all coupled QIF neurons in [5]_. - The model equations are given by: - - .. math:: - - \begin{aligned} - \tau \dot{r} &=\frac{\Delta}{\pi \tau}+2 r v \\ - \tau \dot{v} &=v^{2}+\bar{\eta}+I(t)+J r \tau-(\pi r \tau)^{2} - \end{aligned} - - where :math:`r` is the average firing rate and :math:`v` is the - average membrane potential of the QIF population [5]_. - - This mean-field model is an exact representation of the macroscopic - firing rate and membrane potential dynamics of a spiking neural network - consisting of QIF neurons with Lorentzian distributed background - excitabilities. While the mean-field derivation is mathematically - only valid for all-to-all coupled populations of infinite size, it - has been shown that there is a close correspondence between the - mean-field model and neural populations with sparse coupling and - population sizes of a few thousand neurons [6]_. - - **Model Parameters** - - ============= ============== ======== ======================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------ - tau 1 ms the population time constant - eta -5. \ the mean of a Lorenzian distribution over the neural excitability in the population - delta 1.0 \ the half-width at half maximum of the Lorenzian distribution over the neural excitability - J 15 \ the strength of the recurrent coupling inside the population - ============= ============== ======== ======================== - - - References - ---------- - .. [5] E. Montbrió, D. Pazó, A. Roxin (2015) Macroscopic description for - networks of spiking neurons. Physical Review X, 5:021028, - https://doi.org/10.1103/PhysRevX.5.021028. - .. [6] R. Gast, H. Schmidt, T.R. 
Knösche (2020) A Mean-Field Description - of Bursting Dynamics in Spiking Neural Networks with Short-Term - Adaptation. Neural Computation 32.9 (2020): 1615-1634. - - """ - - def __init__(self, - size: Shape, - tau: Parameter = 1., - eta: Parameter = -5.0, - delta: Parameter = 1.0, - J: Parameter = 15., - method: str = 'exp_auto', - name: str = None): - super(MeanFieldQIF, self).__init__(size=size, name=name) - - # parameters - self.tau = tau # - self.eta = eta # the mean of a Lorenzian distribution over the neural excitability in the population - self.delta = delta # the half-width at half maximum of the Lorenzian distribution over the neural excitability - self.J = J # the strength of the recurrent coupling inside the population - - # variables - self.r = bm.Variable(bm.ones(1)) - self.V = bm.Variable(bm.ones(1)) - self.input = bm.Variable(bm.zeros(1)) - - # functions - self.integral = odeint(self.derivative, method=method) - - def dr(self, r, t, v): - return (self.delta / (bm.pi * self.tau) + 2. * r * v) / self.tau - - def dV(self, v, t, r): - return (v ** 2 + self.eta + self.input + self.J * r * self.tau - - (bm.pi * r * self.tau) ** 2) / self.tau - - @property - def derivative(self): - return JointEq([self.dV, self.dr]) - - def update(self, _t, _dt): - self.V.value, self.r.value = self.integral(self.V, self.r, _t, _dt) - self.integral[:] = 0. - - - -class VanDerPolOscillator(NeuGroup): - pass - - -class ThetaNeuron(NeuGroup): - pass - - -class MeanFieldQIFWithSFA(NeuGroup): - pass - - -class JansenRitModel(NeuGroup): - pass - - -class WilsonCowanModel(NeuGroup): - pass - -class StuartLandauOscillator(NeuGroup): - pass - - -class KuramotoOscillator(NeuGroup): - pass - diff --git a/brainpy/dyn/rates/__init__.py b/brainpy/dyn/rates/__init__.py new file mode 100644 index 000000000..b371df655 --- /dev/null +++ b/brainpy/dyn/rates/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- + +from .base import * +from .fhn import * diff --git a/brainpy/dyn/rates/base.py b/brainpy/dyn/rates/base.py new file mode 100644 index 000000000..8fb14e1b3 --- /dev/null +++ b/brainpy/dyn/rates/base.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- + +from brainpy.dyn.base import DynamicalSystem +from brainpy.tools.others import to_size, size2num +from brainpy.types import Shape + +__all__ = [ + 'RateModel', +] + + +class RateModel(DynamicalSystem): + """Base class of rate models.""" + + def __init__(self, + size: Shape, + name: str = None): + super(RateModel, self).__init__(name=name) + + self.size = to_size(size) + self.num = size2num(self.size) + + def update(self, _t, _dt): + """The function to specify the updating rule. + + Parameters + ---------- + _t : float + The current time. + _dt : float + The time step. + """ + raise NotImplementedError(f'Subclass of {self.__class__.__name__} must ' + f'implement "update" function.') diff --git a/brainpy/dyn/rates/fhn.py b/brainpy/dyn/rates/fhn.py new file mode 100644 index 000000000..72b9c49b9 --- /dev/null +++ b/brainpy/dyn/rates/fhn.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- + +import brainpy.math as bm +from brainpy.integrators import odeint, sdeint, JointEq +from brainpy.types import Parameter, Shape +from brainpy.tools.checking import check_float +from .base import RateModel + +__all__ = [ + 'FHN' +] + + +class FHN(RateModel): + r"""FitzHugh-Nagumo system used in [1]_. + + .. math:: + + \frac{dx}{dt} = -\alpha V^3 + \beta V^2 + \gamma V - w + I_{ext}\\ + \tau \frac{dy}{dt} = (V - \delta - \epsilon w) + + Parameters + ---------- + size: Shape + The model size. 
+ + coupling: str + The way of coupling. + gc: float + The global coupling strength. + signal_speed: float + Signal transmission speed between areas. + sc_mat: optional, tensor + Structural connectivity matrix. Adjacency matrix of coupling strengths, + will be normalized to 1. If not given, then a single node simulation + will be assumed. Default None + fl_mat: optional, tensor + Fiber length matrix. Will be used for computing the + delay matrix together with the signal transmission + speed parameter `signal_speed`. Default None. + + References + ---------- + .. [1] Kostova, T., Ravindran, R., & Schonbek, M. (2004). FitzHugh–Nagumo + revisited: Types of bifurcations, periodical forcing and stability + regions by a Lyapunov functional. International journal of + bifurcation and chaos, 14(03), 913-925. + + """ + + def __init__(self, + size: Shape, + + # fhn parameters + alpha: Parameter = 3.0, + beta: Parameter = 4.0, + gamma: Parameter = -1.5, + delta: Parameter = 0.0, + epsilon: Parameter = 0.5, + tau: Parameter = 20.0, + + # noise parameters + x_ou_mean: Parameter = 0.0, + y_ou_mean: Parameter = 0.0, + ou_sigma: Parameter = 0.0, + ou_tau: Parameter = 5.0, + + # coupling parameters + coupling: str = 'diffusive', + gc=0.6, + signal_speed=20.0, + sc_mat=None, + fl_mat=None, + + # other parameters + method: str = None, + name: str = None): + super(FHN, self).__init__(size, name=name) + + # model parameters + self.alpha = alpha + self.beta = beta + self.gamma = gamma + self.delta = delta + self.epsilon = epsilon + self.tau = tau + + # noise parameters + self.x_ou_mean = x_ou_mean # mV/ms, OU process + self.y_ou_mean = y_ou_mean # mV/ms, OU process + self.ou_sigma = ou_sigma # mV/ms/sqrt(ms), noise intensity + self.ou_tau = ou_tau # ms, timescale of the Ornstein-Uhlenbeck noise process + + # coupling parameters + # ---- + # The coupling parameter determines how nodes are coupled. 
+ # "diffusive" for diffusive coupling, + # "additive" for additive coupling + self.coupling = coupling + assert coupling in ['diffusive', 'additive'], (f'Only support "diffusive" and "additive" ' + f'coupling, while we got {coupling}') + check_float(gc, 'gc', allow_none=False, allow_int=False) + self.gc = gc # global coupling strength + check_float(signal_speed, 'signal_speed', allow_none=False, allow_int=True) + self.signal_speed = signal_speed # signal transmission speed between areas + + + # variables + self.x = bm.Variable(bm.random.random(self.num) * 0.05) + self.y = bm.Variable(bm.random.randint(self.num) * 0.05) + self.x_ou = bm.Variable(bm.ones(self.num) * x_ou_mean) + self.y_ou = bm.Variable(bm.ones(self.num) * y_ou_mean) + self.x_ext = bm.Variable(bm.zeros(self.num)) + self.y_ext = bm.Variable(bm.zeros(self.num)) + + # integral functions + self.int_ou = sdeint(f=self.df_ou, g=self.dg_ou, method='euler') + self.int_xy = odeint(f=JointEq([self.dx, self.dy]), method=method) + + def dx(self, x, t, y, x_ext): + return - self.alpha * x ** 3 + self.beta * x ** 2 + self.gamma * x - y + x_ext + + def dy(self, y, t, x, y_ext=0.): + return (x - self.delta - self.epsilon * y + y_ext) / self.tau + + def df_ou(self, x_ou, y_ou, t): + f_x_ou = (self.x_ou_mean - x_ou) / self.ou_tau + f_y_ou = (self.y_ou_mean - y_ou) / self.ou_tau + return f_x_ou, f_y_ou + + def dg_ou(self, x_ou, y_ou, t): + return self.ou_sigma, self.ou_sigma + + def update(self, _t, _dt): + x_ext = self.x_ext + self.x_ou + y_ext = self.y_ext + self.y_ou + x, y = self.int_xy(self.x, self.y, _t, x_ext=x_ext, y_ext=y_ext, dt=_dt) + self.x.value = x + self.y.value = y + x_ou, y_ou = self.int_ou(self.x_ou, self.y_ou, _t, _dt) + self.x_ou.value = x_ou + self.y_ou.value = y_ou diff --git a/brainpy/dyn/rates/models.py b/brainpy/dyn/rates/models.py new file mode 100644 index 000000000..fd16302dd --- /dev/null +++ b/brainpy/dyn/rates/models.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- + +import brainpy.math as bm +from brainpy.integrators import odeint, sdeint, JointEq +from brainpy.types import Parameter, Shape +from .base import RateModel + +__all__ = [ +] + +class JansenRitModel(RateModel): + pass + + +class WilsonCowanModel(RateModel): + pass + + +class StuartLandauOscillator(RateModel): + pass + + +class KuramotoOscillator(RateModel): + pass + diff --git a/brainpy/dyn/rates/qif.py b/brainpy/dyn/rates/qif.py new file mode 100644 index 000000000..838482ada --- /dev/null +++ b/brainpy/dyn/rates/qif.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- + +import brainpy.math as bm +from brainpy.integrators import odeint, JointEq +from brainpy.types import Parameter, Shape +from .base import RateModel + +__all__ = [ + 'MeanFieldQIF' +] + + +class MeanFieldQIF(RateModel): + r"""A mean-field model of a quadratic integrate-and-fire neuron population. + + **Model Descriptions** + + The QIF population mean-field model, which has been derived from a + population of all-to-all coupled QIF neurons in [5]_. + The model equations are given by: + + .. math:: + + \begin{aligned} + \tau \dot{r} &=\frac{\Delta}{\pi \tau}+2 r v \\ + \tau \dot{v} &=v^{2}+\bar{\eta}+I(t)+J r \tau-(\pi r \tau)^{2} + \end{aligned} + + where :math:`r` is the average firing rate and :math:`v` is the + average membrane potential of the QIF population [5]_. + + This mean-field model is an exact representation of the macroscopic + firing rate and membrane potential dynamics of a spiking neural network + consisting of QIF neurons with Lorentzian distributed background + excitabilities. 
While the mean-field derivation is mathematically + only valid for all-to-all coupled populations of infinite size, it + has been shown that there is a close correspondence between the + mean-field model and neural populations with sparse coupling and + population sizes of a few thousand neurons [6]_. + + **Model Parameters** + + ============= ============== ======== ======================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------ + tau 1 ms the population time constant + eta -5. \ the mean of a Lorenzian distribution over the neural excitability in the population + delta 1.0 \ the half-width at half maximum of the Lorenzian distribution over the neural excitability + J 15 \ the strength of the recurrent coupling inside the population + ============= ============== ======== ======================== + + + References + ---------- + .. [5] E. Montbrió, D. Pazó, A. Roxin (2015) Macroscopic description for + networks of spiking neurons. Physical Review X, 5:021028, + https://doi.org/10.1103/PhysRevX.5.021028. + .. [6] R. Gast, H. Schmidt, T.R. Knösche (2020) A Mean-Field Description + of Bursting Dynamics in Spiking Neural Networks with Short-Term + Adaptation. Neural Computation 32.9 (2020): 1615-1634. + + """ + + def __init__(self, + size: Shape, + tau: Parameter = 1., + eta: Parameter = -5.0, + delta: Parameter = 1.0, + J: Parameter = 15., + method: str = 'exp_auto', + name: str = None): + super(MeanFieldQIF, self).__init__(size=size, name=name) + + # parameters + self.tau = tau # + self.eta = eta # the mean of a Lorenzian distribution over the neural excitability in the population + self.delta = delta # the half-width at half maximum of the Lorenzian distribution over the neural excitability + self.J = J # the strength of the recurrent coupling inside the population + + # variables + self.r = bm.Variable(bm.ones(1)) + self.V = bm.Variable(bm.ones(1)) + self.input = bm.Variable(bm.zeros(1)) + + # functions + self.integral = odeint(self.derivative, method=method) + + def dr(self, r, t, v): + return (self.delta / (bm.pi * self.tau) + 2. * r * v) / self.tau + + def dV(self, v, t, r, I_ext): + return (v ** 2 + self.eta + I_ext + self.J * r * self.tau - + (bm.pi * r * self.tau) ** 2) / self.tau + + @property + def derivative(self): + return JointEq([self.dV, self.dr]) + + def update(self, _t, _dt): + v, r = self.integral(self.V, self.r, t=_t, I_ext=self.input, dt=_dt) + self.V.value = v + self.r.value = r + self.input[:] = 0. 
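The `MeanFieldQIF` node added above keeps its external drive in `self.input`, exposes the population rate `r` and mean potential `V`, and clears the input at the end of every `update` call. Below is a minimal usage sketch; the import path follows the module layout introduced in this patch (`brainpy.dyn.rates.qif`), and the plain stepping loop is only illustrative — it does not assume any particular BrainPy runner.

```python
import brainpy.math as bm
from brainpy.dyn.rates.qif import MeanFieldQIF  # path as added in this patch

# a single QIF population with the default parameters from the class signature
qif = MeanFieldQIF(size=1, tau=1., eta=-5., delta=1., J=15., method='exp_auto')

dt = bm.get_dt()
rates = []
for i in range(1000):
    t = i * dt
    qif.input[:] = 3.0   # constant external current I(t) for this step
    qif.update(t, dt)    # integrates dV/dt and dr/dt, then resets the input
    rates.append(qif.r[0])  # population firing rate after this step
```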
+ + +class ThetaNeuron(RateModel): + pass + + +class MeanFieldQIFWithSFA(RateModel): + pass diff --git a/brainpy/dyn/rates/vdp.py b/brainpy/dyn/rates/vdp.py new file mode 100644 index 000000000..d0de53789 --- /dev/null +++ b/brainpy/dyn/rates/vdp.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- + +from .base import RateModel + + +class VanDerPolOscillator(RateModel): + pass + diff --git a/brainpy/integrators/dde/base.py b/brainpy/integrators/dde/base.py index 413f6a486..380fb6ba4 100644 --- a/brainpy/integrators/dde/base.py +++ b/brainpy/integrators/dde/base.py @@ -25,7 +25,7 @@ def __init__( dt: Union[float, int] = None, name: str = None, show_code: bool = False, - state_delays: Dict[str, bm.FixedLenDelay] = None, + state_delays: Dict[str, bm.TimeDelay] = None, neutral_delays: Dict[str, bm.NeutralDelay] = None, ): dt = bm.get_dt() if dt is None else dt @@ -59,7 +59,7 @@ def __init__( # delays self._state_delays = dict() if state_delays is not None: - check_dict_data(state_delays, key_type=str, val_type=bm.FixedLenDelay) + check_dict_data(state_delays, key_type=str, val_type=bm.TimeDelay) for key, delay in state_delays.items(): if key not in self.variables: raise DiffEqError(f'"{key}" is not defined in the variables: {self.variables}') diff --git a/brainpy/integrators/dde/explicit_rk.py b/brainpy/integrators/dde/explicit_rk.py index 8ba7bcf94..f01f7cc8c 100644 --- a/brainpy/integrators/dde/explicit_rk.py +++ b/brainpy/integrators/dde/explicit_rk.py @@ -4,6 +4,7 @@ from brainpy.integrators.dde.base import DDEIntegrator from brainpy.integrators.ode import common from brainpy.integrators.utils import compile_code, check_kws +from brainpy.integrators.dde.generic import register_dde_integrator __all__ = [ 'ExplicitRKIntegrator', @@ -47,8 +48,6 @@ def __init__(self, f, **kwargs): def integral(*vars, **kwargs): pass - - self.build() def build(self): @@ -72,24 +71,36 @@ class Euler(ExplicitRKIntegrator): C = [0] +register_dde_integrator('euler', Euler) + + class MidPoint(ExplicitRKIntegrator): A = [(), (0.5,)] B = [0, 1] C = [0, 0.5] +register_dde_integrator('midpoint', MidPoint) + + class Heun2(ExplicitRKIntegrator): A = [(), (1,)] B = [0.5, 0.5] C = [0, 1] +register_dde_integrator('heun2', Heun2) + + class Ralston2(ExplicitRKIntegrator): A = [(), ('2/3',)] B = [0.25, 0.75] C = [0, '2/3'] +register_dde_integrator('ralston2', Ralston2) + + class RK2(ExplicitRKIntegrator): def __init__(self, f, beta=2 / 3, var_type=None, dt=None, name=None, show_code=False): self.A = [(), (beta,)] @@ -98,43 +109,67 @@ def __init__(self, f, beta=2 / 3, var_type=None, dt=None, name=None, show_code=F super(RK2, self).__init__(f=f, var_type=var_type, dt=dt, name=name, show_code=show_code) +register_dde_integrator('rk2', RK2) + + class RK3(ExplicitRKIntegrator): A = [(), (0.5,), (-1, 2)] B = ['1/6', '2/3', '1/6'] C = [0, 0.5, 1] +register_dde_integrator('rk3', RK3) + + class Heun3(ExplicitRKIntegrator): A = [(), ('1/3',), (0, '2/3')] B = [0.25, 0, 0.75] C = [0, '1/3', '2/3'] +register_dde_integrator('heun3', Heun3) + + class Ralston3(ExplicitRKIntegrator): A = [(), (0.5,), (0, 0.75)] B = ['2/9', '1/3', '4/9'] C = [0, 0.5, 0.75] +register_dde_integrator('ralston3', Ralston3) + + class SSPRK3(ExplicitRKIntegrator): A = [(), (1,), (0.25, 0.25)] B = ['1/6', '1/6', '2/3'] C = [0, 1, 0.5] +register_dde_integrator('ssprk3', SSPRK3) + + class RK4(ExplicitRKIntegrator): A = [(), (0.5,), (0., 0.5), (0., 0., 1)] B = ['1/6', '1/3', '1/3', '1/6'] C = [0, 0.5, 0.5, 1] +register_dde_integrator('rk4', RK4) + + class 
Ralston4(ExplicitRKIntegrator): A = [(), (.4,), (.29697761, .15875964), (.21810040, -3.05096516, 3.83286476)] B = [.17476028, -.55148066, 1.20553560, .17118478] C = [0, .4, .45573725, 1] +register_dde_integrator('ralston4', Ralston4) + + class RK4Rule38(ExplicitRKIntegrator): A = [(), ('1/3',), ('-1/3', '1'), (1, -1, 1)] B = [0.125, 0.375, 0.375, 0.125] C = [0, '1/3', '2/3', 1] + + +register_dde_integrator('rk4_38rule', RK4Rule38) diff --git a/brainpy/integrators/dde/generic.py b/brainpy/integrators/dde/generic.py index 8eb5a0ec4..29087725a 100644 --- a/brainpy/integrators/dde/generic.py +++ b/brainpy/integrators/dde/generic.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- from .base import DDEIntegrator -from .explicit_rk import * __all__ = [ 'ddeint', @@ -12,19 +11,6 @@ ] name2method = { - # explicit RK - 'euler': Euler, 'Euler': Euler, - 'midpoint': MidPoint, 'MidPoint': MidPoint, - 'heun2': Heun2, 'Heun2': Heun2, - 'ralston2': Ralston2, 'Ralston2': Ralston2, - 'rk2': RK2, 'RK2': RK2, - 'rk3': RK3, 'RK3': RK3, - 'heun3': Heun3, 'Heun3': Heun3, - 'ralston3': Ralston3, 'Ralston3': Ralston3, - 'ssprk3': SSPRK3, 'SSPRK3': SSPRK3, - 'rk4': RK4, 'RK4': RK4, - 'ralston4': Ralston4, 'Ralston4': Ralston4, - 'rk4_38rule': RK4Rule38, 'RK4Rule38': RK4Rule38, } @@ -132,7 +118,7 @@ def register_dde_integrator(name, integrator): """ if name in name2method: raise ValueError(f'"{name}" has been registered in DDE integrators.') - if DDEIntegrator not in integrator.__bases__: + if not issubclass(integrator, DDEIntegrator): raise ValueError(f'"integrator" must be an instance of {DDEIntegrator.__name__}') name2method[name] = integrator diff --git a/brainpy/integrators/fde/Caputo.py b/brainpy/integrators/fde/Caputo.py index a541e3b48..ff62f7b47 100644 --- a/brainpy/integrators/fde/Caputo.py +++ b/brainpy/integrators/fde/Caputo.py @@ -8,6 +8,7 @@ import jax.numpy as jnp from jax.experimental.host_callback import id_tap +from brainpy import check import brainpy.math as bm from brainpy.errors import UnsupportedError from brainpy.integrators.constants import DT @@ -150,7 +151,8 @@ def _integral_func(self, *args, **kwargs): # format arguments all_args = format_args(args, kwargs, self.arguments) dt = all_args.pop(DT, self.dt) - id_tap(self._check_step, (dt, all_args['t'])) + if check.is_checking(): + id_tap(self._check_step, (dt, all_args['t'])) # derivative values devs = self.f(**all_args) @@ -360,7 +362,8 @@ def _integral_func(self, *args, **kwargs): # format arguments all_args = format_args(args, kwargs, self.arguments) dt = all_args.pop(DT, self.dt) - id_tap(self._check_step, (dt, all_args['t'])) + if check.is_checking(): + id_tap(self._check_step, (dt, all_args['t'])) # derivative values devs = self.f(**all_args) diff --git a/brainpy/integrators/fde/generic.py b/brainpy/integrators/fde/generic.py index 7aadcd7b5..07d6b17dc 100644 --- a/brainpy/integrators/fde/generic.py +++ b/brainpy/integrators/fde/generic.py @@ -82,7 +82,7 @@ def register_fde_integrator(name, integrator): """ if name in name2method: raise ValueError(f'"{name}" has been registered in ODE integrators.') - if FDEIntegrator not in integrator.__bases__: + if not issubclass(integrator, FDEIntegrator): raise ValueError(f'"integrator" must be an instance of {FDEIntegrator.__name__}') name2method[name] = integrator diff --git a/brainpy/integrators/joint_eq.py b/brainpy/integrators/joint_eq.py index 59c0ebdb3..b98f6c9dd 100644 --- a/brainpy/integrators/joint_eq.py +++ b/brainpy/integrators/joint_eq.py @@ -153,7 +153,7 @@ def __init__(self, eqs): for par in 
args[len(vars) + 1:]: if (par not in vars_in_eqs) and (par not in all_arg_pars) and (par not in all_kwarg_pars): all_arg_pars.append(par) - for key, value in kwargs.values(): + for key, value in kwargs.items(): if key in all_kwarg_pars and value != all_kwarg_pars[key]: raise errors.DiffEqError(f'We got two different default value of "{key}": ' f'{all_kwarg_pars[key]} != {value}') diff --git a/brainpy/integrators/ode/adaptive_rk.py b/brainpy/integrators/ode/adaptive_rk.py index d2f4d0062..40462bcb5 100644 --- a/brainpy/integrators/ode/adaptive_rk.py +++ b/brainpy/integrators/ode/adaptive_rk.py @@ -58,6 +58,7 @@ from brainpy.integrators import constants as C, utils from brainpy.integrators.ode import common from brainpy.integrators.ode.base import ODEIntegrator +from .generic import register_ode_integrator __all__ = [ 'AdaptiveRKIntegrator', @@ -239,6 +240,9 @@ class RKF12(AdaptiveRKIntegrator): C = [0, 0.5, 1] +register_ode_integrator('rkf12', RKF12) + + class RKF45(AdaptiveRKIntegrator): r"""The Runge–Kutta–Fehlberg method for ODEs. @@ -285,6 +289,9 @@ class RKF45(AdaptiveRKIntegrator): C = [0, 0.25, 0.375, '12/13', 1, '1/3'] +register_ode_integrator('rkf45', RKF45) + + class DormandPrince(AdaptiveRKIntegrator): r"""The Dormand–Prince method for ODEs. @@ -336,6 +343,9 @@ class DormandPrince(AdaptiveRKIntegrator): C = [0, 0.2, 0.3, 0.8, '8/9', 1, 1] +register_ode_integrator('rkdp', DormandPrince) + + class CashKarp(AdaptiveRKIntegrator): r"""The Cash–Karp method for ODEs. @@ -384,6 +394,9 @@ class CashKarp(AdaptiveRKIntegrator): C = [0, 0.2, 0.3, 0.6, 1, 0.875] +register_ode_integrator('ck', CashKarp) + + class BogackiShampine(AdaptiveRKIntegrator): r"""The Bogacki–Shampine method for ODEs. @@ -427,6 +440,9 @@ class BogackiShampine(AdaptiveRKIntegrator): C = [0, 0.5, 0.75, 1] +register_ode_integrator('bs', BogackiShampine) + + class HeunEuler(AdaptiveRKIntegrator): r"""The Heun–Euler method for ODEs. @@ -457,6 +473,9 @@ class HeunEuler(AdaptiveRKIntegrator): C = [0, 1] +register_ode_integrator('heun_euler', HeunEuler) + + class DOP853(AdaptiveRKIntegrator): # def DOP853(f=None, tol=None, adaptive=None, dt=None, show_code=None, each_var_is_scalar=None): r"""The DOP853 method for ODEs. @@ -484,9 +503,12 @@ class BoSh3(AdaptiveRKIntegrator): """ A = [(), - (0.5, ), + (0.5,), (0.0, 0.75), ('2/9', '1/3', '4/9')] B1 = ['2/9', '1/3', '4/9', 0.0] - B2 = ['-5/72', 1/12, '1/9', '-1/8'] + B2 = ['-5/72', 1 / 12, '1/9', '-1/8'] C = [0., 0.5, 0.75, 1.0] + + +register_ode_integrator('BoSh3', BoSh3) diff --git a/brainpy/integrators/ode/explicit_rk.py b/brainpy/integrators/ode/explicit_rk.py index 3d71a0ea7..ee54b9005 100644 --- a/brainpy/integrators/ode/explicit_rk.py +++ b/brainpy/integrators/ode/explicit_rk.py @@ -70,6 +70,7 @@ from brainpy.integrators import constants as C, utils from brainpy.integrators.ode import common from brainpy.integrators.ode.base import ODEIntegrator +from .generic import register_ode_integrator __all__ = [ 'ExplicitRKIntegrator', @@ -247,6 +248,9 @@ class Euler(ExplicitRKIntegrator): C = [0] +register_ode_integrator('euler', Euler) + + class MidPoint(ExplicitRKIntegrator): r"""Explicit midpoint method for ODEs. @@ -341,6 +345,9 @@ class MidPoint(ExplicitRKIntegrator): C = [0, 0.5] +register_ode_integrator('midpoint', MidPoint) + + class Heun2(ExplicitRKIntegrator): r"""Heun's method for ODEs. @@ -406,6 +413,9 @@ class Heun2(ExplicitRKIntegrator): C = [0, 1] +register_ode_integrator('heun2', Heun2) + + class Ralston2(ExplicitRKIntegrator): r"""Ralston's method for ODEs. 
@@ -437,6 +447,9 @@ class Ralston2(ExplicitRKIntegrator): C = [0, '2/3'] +register_ode_integrator('ralston2', Ralston2) + + class RK2(ExplicitRKIntegrator): r"""Generic second order Runge-Kutta method for ODEs. @@ -560,6 +573,9 @@ def __init__(self, f, beta=2 / 3, var_type=None, dt=None, name=None, show_code=F super(RK2, self).__init__(f=f, var_type=var_type, dt=dt, name=name, show_code=show_code) +register_ode_integrator('rk2', RK2) + + class RK3(ExplicitRKIntegrator): r"""Classical third-order Runge-Kutta method for ODEs. @@ -598,6 +614,9 @@ class RK3(ExplicitRKIntegrator): C = [0, 0.5, 1] +register_ode_integrator('rk3', RK3) + + class Heun3(ExplicitRKIntegrator): r"""Heun's third-order method for ODEs. @@ -622,6 +641,9 @@ class Heun3(ExplicitRKIntegrator): C = [0, '1/3', '2/3'] +register_ode_integrator('heun3', Heun3) + + class Ralston3(ExplicitRKIntegrator): r"""Ralston's third-order method for ODEs. @@ -651,6 +673,9 @@ class Ralston3(ExplicitRKIntegrator): C = [0, 0.5, 0.75] +register_ode_integrator('ralston3', Ralston3) + + class SSPRK3(ExplicitRKIntegrator): r"""Third-order Strong Stability Preserving Runge-Kutta (SSPRK3). @@ -674,6 +699,9 @@ class SSPRK3(ExplicitRKIntegrator): C = [0, 1, 0.5] +register_ode_integrator('ssprk3', SSPRK3) + + class RK4(ExplicitRKIntegrator): r"""Classical fourth-order Runge-Kutta method for ODEs. @@ -741,6 +769,9 @@ class RK4(ExplicitRKIntegrator): C = [0, 0.5, 0.5, 1] +register_ode_integrator('rk4', RK4) + + class Ralston4(ExplicitRKIntegrator): r"""Ralston's fourth-order method for ODEs. @@ -772,6 +803,9 @@ class Ralston4(ExplicitRKIntegrator): C = [0, .4, .45573725, 1] +register_ode_integrator('ralston4', Ralston4) + + class RK4Rule38(ExplicitRKIntegrator): r"""3/8-rule fourth-order method for ODEs. @@ -811,3 +845,6 @@ class RK4Rule38(ExplicitRKIntegrator): A = [(), ('1/3',), ('-1/3', '1'), (1, -1, 1)] B = [0.125, 0.375, 0.375, 0.125] C = [0, '1/3', '2/3', 1] + + +register_ode_integrator('rk4_38rule', RK4Rule38) diff --git a/brainpy/integrators/ode/exponential.py b/brainpy/integrators/ode/exponential.py index 5042fb803..7a96b0be8 100644 --- a/brainpy/integrators/ode/exponential.py +++ b/brainpy/integrators/ode/exponential.py @@ -113,6 +113,7 @@ from brainpy.integrators import constants as C, utils, joint_eq from brainpy.integrators.analysis_by_ast import separate_variables from brainpy.integrators.ode.base import ODEIntegrator +from .generic import register_ode_integrator try: import sympy @@ -506,6 +507,10 @@ def solve(self, diff_eq, var): return s_df_part +register_ode_integrator('exponential_euler', ExponentialEuler) +register_ode_integrator('exp_euler', ExponentialEuler) + + class ExpEulerAuto(ODEIntegrator): """Exponential Euler method using automatic differentiation. 
@@ -762,3 +767,7 @@ def integral(*args, **kwargs): return args[0] + dt * phi * derivative return [(integral, vars, pars), ] + + +register_ode_integrator('exp_euler_auto', ExpEulerAuto) +register_ode_integrator('exp_auto', ExpEulerAuto) diff --git a/brainpy/integrators/ode/generic.py b/brainpy/integrators/ode/generic.py index 9e1afeb37..50d3014ee 100644 --- a/brainpy/integrators/ode/generic.py +++ b/brainpy/integrators/ode/generic.py @@ -1,9 +1,6 @@ # -*- coding: utf-8 -*- from .base import ODEIntegrator -from .adaptive_rk import * -from .explicit_rk import * -from .exponential import * __all__ = [ 'odeint', @@ -14,31 +11,6 @@ ] name2method = { - # explicit RK - 'euler': Euler, 'Euler': Euler, - 'midpoint': MidPoint, 'MidPoint': MidPoint, - 'heun2': Heun2, 'Heun2': Heun2, - 'ralston2': Ralston2, 'Ralston2': Ralston2, - 'rk2': RK2, 'RK2': RK2, - 'rk3': RK3, 'RK3': RK3, - 'heun3': Heun3, 'Heun3': Heun3, - 'ralston3': Ralston3, 'Ralston3': Ralston3, - 'ssprk3': SSPRK3, 'SSPRK3': SSPRK3, - 'rk4': RK4, 'RK4': RK4, - 'ralston4': Ralston4, 'Ralston4': Ralston4, - 'rk4_38rule': RK4Rule38, 'RK4Rule38': RK4Rule38, - - # adaptive RK - 'rkf12': RKF12, 'RKF12': RKF12, - 'rkf45': RKF45, 'RKF45': RKF45, - 'rkdp': DormandPrince, 'dp': DormandPrince, 'DormandPrince': DormandPrince, - 'ck': CashKarp, 'CashKarp': CashKarp, - 'bs': BogackiShampine, 'BogackiShampine': BogackiShampine, - 'heun_euler': HeunEuler, 'HeunEuler': HeunEuler, - - # exponential integrators - 'exponential_euler': ExponentialEuler, 'exp_euler': ExponentialEuler, 'ExponentialEuler': ExponentialEuler, - 'exp_euler_auto': ExpEulerAuto, 'exp_auto': ExpEulerAuto, 'ExpEulerAuto': ExpEulerAuto, } _DEFAULT_DDE_METHOD = 'euler' @@ -134,7 +106,7 @@ def register_ode_integrator(name, integrator): """ if name in name2method: raise ValueError(f'"{name}" has been registered in ODE integrators.') - if ODEIntegrator not in integrator.__bases__: + if not issubclass(integrator, ODEIntegrator): raise ValueError(f'"integrator" must be an instance of {ODEIntegrator.__name__}') name2method[name] = integrator diff --git a/brainpy/integrators/runner.py b/brainpy/integrators/runner.py index dcba74284..39239e481 100644 --- a/brainpy/integrators/runner.py +++ b/brainpy/integrators/runner.py @@ -93,7 +93,7 @@ class IntegratorRunner(Runner): >>> dt = 0.01; beta=2.; gamma=1.; tau=2.; n=9.65 >>> mg_eq = lambda x, t, xdelay: (beta * xdelay(t - tau) / (1 + xdelay(t - tau) ** n) >>> - gamma * x) - >>> xdelay = bm.FixedLenDelay(1, delay_len=tau, dt=dt, before_t0=lambda t: 1.2) + >>> xdelay = bm.TimeDelay(1, delay_len=tau, dt=dt, before_t0=lambda t: 1.2) >>> integral = bp.ddeint(mg_eq, method='rk4', state_delays={'x': xdelay}) >>> runner = bp.integrators.IntegratorRunner( >>> integral, diff --git a/brainpy/integrators/sde/generic.py b/brainpy/integrators/sde/generic.py index 05ffd9c21..36259d296 100644 --- a/brainpy/integrators/sde/generic.py +++ b/brainpy/integrators/sde/generic.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- from .base import SDEIntegrator -from .normal import * -from .srk_scalar import * __all__ = [ 'sdeint', @@ -13,15 +11,6 @@ ] name2method = { - 'euler': Euler, 'Euler': Euler, - 'heun': Heun, 'Heun': Heun, - 'milstein': Milstein, 'Milstein': Milstein, - 'exponential_euler': ExponentialEuler, 'exp_euler': ExponentialEuler, 'ExponentialEuler': ExponentialEuler, - - # RK methods - 'srk1w1': SRK1W1, 'SRK1W1': SRK1W1, - 'srk2w1': SRK2W1, 'SRK2W1': SRK2W1, - 'klpl': KlPl, 'KlPl': KlPl, } _DEFAULT_SDE_METHOD = 'euler' @@ -98,7 +87,7 @@ def 
register_sde_integrator(name, integrator): """ if name in name2method: raise ValueError(f'"{name}" has been registered in SDE integrators.') - if SDEIntegrator not in integrator.__bases__: + if not issubclass(integrator, SDEIntegrator): raise ValueError(f'"integrator" must be an instance of {SDEIntegrator.__name__}') name2method[name] = integrator diff --git a/brainpy/integrators/sde/normal.py b/brainpy/integrators/sde/normal.py index d9e9ef739..ff296c050 100644 --- a/brainpy/integrators/sde/normal.py +++ b/brainpy/integrators/sde/normal.py @@ -6,6 +6,7 @@ from brainpy.integrators import constants, utils from brainpy.integrators.analysis_by_ast import separate_variables from brainpy.integrators.sde.base import SDEIntegrator +from .generic import register_sde_integrator try: import sympy @@ -142,6 +143,9 @@ def build(self): func_name=self.func_name) +register_sde_integrator('euler', Euler) + + class Heun(Euler): def __init__(self, f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None): @@ -154,6 +158,9 @@ def __init__(self, f, g, dt=None, name=None, show_code=False, self.build() +register_sde_integrator('heun', Heun) + + class Milstein(SDEIntegrator): def __init__(self, f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None): @@ -238,6 +245,9 @@ def build(self): func_name=self.func_name) +register_sde_integrator('milstein', Milstein) + + class ExponentialEuler(SDEIntegrator): r"""First order, explicit exponential Euler method. @@ -399,3 +409,7 @@ def symbolic_build(self): if hasattr(self.derivative[constants.F], '__self__'): host = self.derivative[constants.F].__self__ self.integral = self.integral.__get__(host, host.__class__) + + +register_sde_integrator('exponential_euler', ExponentialEuler) +register_sde_integrator('exp_euler', ExponentialEuler) diff --git a/brainpy/integrators/sde/srk_scalar.py b/brainpy/integrators/sde/srk_scalar.py index c95164df0..47535ed65 100644 --- a/brainpy/integrators/sde/srk_scalar.py +++ b/brainpy/integrators/sde/srk_scalar.py @@ -2,6 +2,7 @@ from brainpy.integrators import constants, utils from brainpy.integrators.sde.base import SDEIntegrator +from .generic import register_sde_integrator __all__ = [ 'SRK1W1', @@ -175,6 +176,9 @@ def build(self): func_name=self.func_name) +register_sde_integrator('srk1w1', SRK1W1) + + class SRK2W1(SDEIntegrator): r"""Order 1.5 Strong SRK Methods for SDEs with Scalar Noise. 
@@ -315,6 +319,9 @@ def build(self): func_name=self.func_name) +register_sde_integrator('srk2w1', SRK2W1) + + class KlPl(SDEIntegrator): def __init__(self, f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None): @@ -354,7 +361,7 @@ def build(self): self.code_lines.append(f' {var}_g1 = -{var}_I1 + {var}_I11/dt_sqrt + {var}_I10/{constants.DT}') self.code_lines.append(f' {var}_g2 = {var}_I11 / dt_sqrt') self.code_lines.append(f' {var}_new = {var} + {constants.DT} * {var}_f_H0s1 + ' - f'{var}_g1 * {var}_g_H1s1 + {var}_g2 * {var}_g_H1s2') + f'{var}_g1 * {var}_g_H1s1 + {var}_g2 * {var}_g_H1s2') self.code_lines.append(' ') # returns @@ -367,3 +374,6 @@ def build(self): code_lines=self.code_lines, show_code=self.show_code, func_name=self.func_name) + + +register_sde_integrator('klpl', KlPl) diff --git a/brainpy/math/delay_vars.py b/brainpy/math/delay_vars.py index 20e990eab..4c1079b9b 100644 --- a/brainpy/math/delay_vars.py +++ b/brainpy/math/delay_vars.py @@ -1,21 +1,24 @@ # -*- coding: utf-8 -*- - +import warnings from typing import Union, Callable, Tuple import jax.numpy as jnp +import numpy as np from jax import vmap from jax.experimental.host_callback import id_tap from jax.lax import cond +from brainpy import check from brainpy import math as bm from brainpy.base.base import Base +from brainpy.errors import UnsupportedError from brainpy.tools.checking import check_float from brainpy.tools.others import to_size -from brainpy.errors import UnsupportedError __all__ = [ 'AbstractDelay', + 'TimeDelay', 'FixedLenDelay', 'NeutralDelay', ] @@ -32,13 +35,13 @@ def update(self, time, value): _INTERP_ROUND = 'round' -class FixedLenDelay(AbstractDelay): - """Delay variable which has a fixed delay length. +class TimeDelay(AbstractDelay): + """Delay variable which has a fixed delay time length. For example, we create a delay variable which has a maximum delay length of 1 ms >>> import brainpy.math as bm - >>> delay = bm.FixedLenDelay(bm.zeros(3), delay_len=1., dt=0.1) + >>> delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1) >>> delay(-0.5) [-0. -0. -0.] @@ -46,13 +49,13 @@ class FixedLenDelay(AbstractDelay): 1. the one-dimensional delay data - >>> delay = bm.FixedLenDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t) + >>> delay = bm.TimeDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t) >>> delay(-0.2) [-0.2 -0.2 -0.2] 2. the two-dimensional delay data - >>> delay = bm.FixedLenDelay((3, 2), delay_len=1., dt=0.1, before_t0=lambda t: t) + >>> delay = bm.TimeDelay((3, 2), delay_len=1., dt=0.1, before_t0=lambda t: t) >>> delay(-0.6) [[-0.6 -0.6] [-0.6 -0.6] @@ -60,7 +63,7 @@ class FixedLenDelay(AbstractDelay): 3. 
the three-dimensional delay data - >>> delay = bm.FixedLenDelay((3, 2, 1), delay_len=1., dt=0.1, before_t0=lambda t: t) + >>> delay = bm.TimeDelay((3, 2, 1), delay_len=1., dt=0.1, before_t0=lambda t: t) >>> delay(-0.6) [[[-0.8] [-0.8]] @@ -113,7 +116,7 @@ def __init__( dtype=None, interp_method='linear_interp', ): - super(FixedLenDelay, self).__init__(name=name) + super(TimeDelay, self).__init__(name=name) # shape self.shape = to_size(shape) @@ -161,6 +164,10 @@ def __init__( else: raise ValueError(f'"before_t0" does not support {type(before_t0)}: before_t0') + self.f = jnp.interp + for dim in range(1, len(self.shape) + 1, 1): + self.f = vmap(self.f, in_axes=(None, None, dim), out_axes=dim - 1) + @property def idx(self): return self._idx @@ -191,36 +198,37 @@ def current_time(self): def _check_time(self, times, transforms): prev_time, current_time = times - current_time = bm.as_device_array(current_time) - prev_time = bm.as_device_array(prev_time) + current_time = np.asarray(current_time, dtype=bm.float_) + prev_time = np.asarray(prev_time, dtype=bm.float_) if prev_time > current_time: raise ValueError(f'\n' f'!!! Error in {self.__class__.__name__}: \n' f'The request time should be less than the ' f'current time {current_time}. But we ' f'got {prev_time} > {current_time}') - lower_time = jnp.asarray(current_time - self.delay_len) + lower_time = np.asarray(current_time - self.delay_len) if prev_time < lower_time: raise ValueError(f'\n' f'!!! Error in {self.__class__.__name__}: \n' f'The request time of the variable should be in ' f'[{lower_time}, {current_time}], but we got {prev_time}') - def __call__(self, prev_time): + def __call__(self, time, indices=None): # check - id_tap(self._check_time, (prev_time, self.current_time)) + if check.is_checking(): + id_tap(self._check_time, (time, self.current_time)) if self._before_type == _FUNC_BEFORE: - return cond(prev_time < self.t0, + return cond(time < self.t0, self._before_t0, self._after_t0, - prev_time) + time) else: - return self._after_t0(prev_time) + return self._after_t0(time) def _after_t0(self, prev_time): diff = self.delay_len - (self.current_time - prev_time) - if isinstance(diff, bm.ndarray): diff = diff.value - + if isinstance(diff, bm.ndarray): + diff = diff.value if self.interp_method == _INTERP_LINEAR: req_num_step = jnp.asarray(diff / self._dt, dtype=bm.get_dint()) extra = diff - req_num_step * self._dt @@ -238,13 +246,10 @@ def _true_fn(self, div_mod): def _false_fn(self, div_mod): req_num_step, extra = div_mod - f = jnp.interp - for dim in range(1, len(self.shape) + 1, 1): - f = vmap(f, in_axes=(None, None, dim), out_axes=dim - 1) idx = jnp.asarray([self.idx[0] + req_num_step, self.idx[0] + req_num_step + 1]) idx %= self.num_delay_step - return f(extra, jnp.asarray([0., self._dt]), self._data[idx]) + return self.f(extra, jnp.asarray([0., self._dt]), self._data[idx]) def update(self, time, value): self._data[self._idx[0]] = value @@ -252,17 +257,32 @@ def update(self, time, value): self._idx.value = (self._idx + 1) % self.num_delay_step -class VariedLenDelay(AbstractDelay): - """Delay variable which has a functional delay - - """ +def FixedLenDelay(shape: Union[int, Tuple[int, ...]], + delay_len: Union[float, int], + before_t0: Union[Callable, bm.ndarray, jnp.ndarray, float, int] = None, + t0: Union[float, int] = 0., + dt: Union[float, int] = None, + name: str = None, + dtype=None, + interp_method='linear_interp', ): + warnings.warn('Please use "brainpy.math.TimeDelay" instead. 
' + '"brainpy.math.FixedLenDelay" is deprecated since version 2.1.2. ', + DeprecationWarning) + return TimeDelay(shape=shape, + delay_len=delay_len, + before_t0=before_t0, + t0=t0, + dt=dt, + name=name, + dtype=dtype, + interp_method=interp_method) + + +class NeutralDelay(TimeDelay): + pass - def update(self, time, value): - pass - def __init__(self): - super(VariedLenDelay, self).__init__() +class LengthDelay(AbstractDelay): + pass -class NeutralDelay(FixedLenDelay): - pass diff --git a/brainpy/math/parallels.py b/brainpy/math/parallels.py index 080b560f9..a8e0de7c2 100644 --- a/brainpy/math/parallels.py +++ b/brainpy/math/parallels.py @@ -27,6 +27,7 @@ from brainpy.base.base import Base from brainpy.base.collector import TensorCollector from brainpy.math.random import RandomState +from brainpy.math.jaxarray import JaxArray from brainpy.tools.codes import change_func_name __all__ = [ @@ -77,7 +78,7 @@ def vmap(func, dyn_vars=None, batched_vars=None, ---------- func : Base, function, callable The function or the module to compile. - dyn_vars : dict + dyn_vars : dict, sequence batched_vars : dict in_axes : optional, int, sequence of int Specify which input array axes to map over. If each positional argument to @@ -207,13 +208,19 @@ def vmap(func, dyn_vars=None, batched_vars=None, axis_name=axis_name) else: + if isinstance(dyn_vars, JaxArray): + dyn_vars = [dyn_vars] + if isinstance(dyn_vars, (tuple, list)): + dyn_vars = {f'_vmap_v{i}': v for i, v in enumerate(dyn_vars)} + assert isinstance(dyn_vars, dict) + # dynamical variables - dyn_vars, rand_vars = TensorCollector(), TensorCollector() + _dyn_vars, _rand_vars = TensorCollector(), TensorCollector() for key, val in dyn_vars.items(): if isinstance(val, RandomState): - rand_vars[key] = val + _rand_vars[key] = val else: - dyn_vars[key] = val + _dyn_vars[key] = val # in axes if in_axes is None: @@ -249,8 +256,8 @@ def vmap(func, dyn_vars=None, batched_vars=None, # jit function return _make_vmap(func=func, - dyn_vars=dyn_vars, - rand_vars=rand_vars, + dyn_vars=_dyn_vars, + rand_vars=_rand_vars, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name, diff --git a/brainpy/math/tests/test_delay_vars.py b/brainpy/math/tests/test_delay_vars.py index 6eb6e1e87..475651fc4 100644 --- a/brainpy/math/tests/test_delay_vars.py +++ b/brainpy/math/tests/test_delay_vars.py @@ -12,7 +12,7 @@ def test_dim1(self): # linear interp t0 = 0. before_t0 = bm.repeat(bm.arange(11).reshape((-1, 1)), 10, axis=1) - delay = bm.FixedLenDelay(10, delay_len=1., t0=t0, dt=0.1, before_t0=before_t0) + delay = bm.TimeDelay(10, delay_len=1., t0=t0, dt=0.1, before_t0=before_t0) self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones(10) * 10)) self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones(10) * 9.5)) print() @@ -21,8 +21,8 @@ def test_dim1(self): # self.assertTrue(bm.array_equal(delay(t0 - 0.23), bm.ones(10) * 8.7)) # round interp - delay = bm.FixedLenDelay(10, delay_len=1., t0=t0, dt=0.1, before_t0=before_t0, - interp_method='round') + delay = bm.TimeDelay(10, delay_len=1., t0=t0, dt=0.1, before_t0=before_t0, + interp_method='round') self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones(10) * 10)) self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones(10) * 10)) self.assertTrue(bm.array_equal(delay(t0 - 0.2), bm.ones(10) * 9)) @@ -31,7 +31,7 @@ def test_dim2(self): t0 = 0. 
before_t0 = bm.repeat(bm.arange(11).reshape((-1, 1)), 10, axis=1) before_t0 = bm.repeat(before_t0.reshape((11, 10, 1)), 5, axis=2) - delay = bm.FixedLenDelay((10, 5), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0) + delay = bm.TimeDelay((10, 5), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0) self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones((10, 5)) * 10)) self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones((10, 5)) * 9.5)) # self.assertTrue(bm.array_equal(delay(t0 - 0.23), bm.ones((10, 5)) * 8.7)) @@ -41,27 +41,27 @@ def test_dim3(self): before_t0 = bm.repeat(bm.arange(11).reshape((-1, 1)), 10, axis=1) before_t0 = bm.repeat(before_t0.reshape((11, 10, 1)), 5, axis=2) before_t0 = bm.repeat(before_t0.reshape((11, 10, 5, 1)), 3, axis=3) - delay = bm.FixedLenDelay((10, 5, 3), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0) + delay = bm.TimeDelay((10, 5, 3), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0) self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones((10, 5, 3)) * 10)) self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones((10, 5, 3)) * 9.5)) # self.assertTrue(bm.array_equal(delay(t0 - 0.23), bm.ones((10, 5, 3)) * 8.7)) def test1(self): print() - delay = bm.FixedLenDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t) + delay = bm.TimeDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t) print(delay(-0.2)) - delay = bm.FixedLenDelay((3, 2), delay_len=1., dt=0.1, before_t0=lambda t: t) + delay = bm.TimeDelay((3, 2), delay_len=1., dt=0.1, before_t0=lambda t: t) print(delay(-0.6)) - delay = bm.FixedLenDelay((3, 2, 1), delay_len=1., dt=0.1, before_t0=lambda t: t) + delay = bm.TimeDelay((3, 2, 1), delay_len=1., dt=0.1, before_t0=lambda t: t) print(delay(-0.8)) def test_current_time2(self): print() - delay = bm.FixedLenDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t) + delay = bm.TimeDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t) print(delay(0.)) before_t0 = bm.repeat(bm.arange(11).reshape((-1, 1)), 10, axis=1) before_t0 = bm.repeat(before_t0.reshape((11, 10, 1)), 5, axis=2) - delay = bm.FixedLenDelay((10, 5), delay_len=1., dt=0.1, before_t0=before_t0) + delay = bm.TimeDelay((10, 5), delay_len=1., dt=0.1, before_t0=before_t0) print(delay(0.)) # def test_prev_time_beyond_boundary(self): diff --git a/brainpy/measure/correlation.py b/brainpy/measure/correlation.py index 27a77b8db..5514a56b6 100644 --- a/brainpy/measure/correlation.py +++ b/brainpy/measure/correlation.py @@ -186,7 +186,6 @@ def matrix_correlation(x, y): return cc -@jit def functional_connectivity(activities): """Functional connectivity matrix of timeseries activities. @@ -200,12 +199,12 @@ def functional_connectivity(activities): connectivity_matrix: tensor ``num_sample x num_sample`` functional connectivity matrix. """ - activities = bm.asarray(activities) + activities = bm.as_numpy(activities) if activities.ndim != 2: raise ValueError('Only support 2d tensor with shape of "(num_time, num_sample)". 
' f'But we got a tensor with the shape of {activities.shape}') - fc = bm.corrcoef(activities.T) - return bm.nan_to_num(fc) + fc = np.corrcoef(activities.T) + return np.nan_to_num(fc) @jit diff --git a/brainpy/nn/base.py b/brainpy/nn/base.py index 972386cdb..4bf5d5ddc 100644 --- a/brainpy/nn/base.py +++ b/brainpy/nn/base.py @@ -352,8 +352,8 @@ def output_shape(self, size): @property def is_feedback_input_supported(self): - if hasattr(self.init_fb, 'not_implemented'): - if self.init_fb.not_implemented: + if hasattr(self.init_fb_conn, 'not_implemented'): + if self.init_fb_conn.not_implemented: return False return True @@ -405,10 +405,10 @@ def copy(self, new_obj.name = self.unique_name(name or (self.name + '_copy')) return new_obj - def _init_ff(self): + def _init_ff_conn(self): if not self._is_ff_initialized: try: - self.init_ff() + self.init_ff_conn() except Exception as e: raise ModelBuildError(f'{self.name} initialization failed.') from e self._is_ff_initialized = True @@ -416,43 +416,48 @@ def _init_ff(self): raise ValueError(f'Please set the output shape when implementing ' f'"init_ff()" of the node {self.name}') - def _init_fb(self): + def _init_fb_conn(self): if not self._is_fb_initialized: try: - self.init_fb() + self.init_fb_conn() except Exception as e: raise ModelBuildError(f"{self.name} initialization failed.") from e self._is_fb_initialized = True @not_implemented - def init_fb(self): + def init_fb_conn(self): """Initialize the feedback connections. This function will be called only once.""" raise ValueError(f'This node \n\n{self} \n\ndoes not support feedback connection.') - def init_ff(self): + def init_ff_conn(self): """Initialize the feedforward connections. This function will be called only once.""" raise NotImplementedError('Please implement the feedforward initialization.') - def init_state(self, num_batch=1): - """Initialize the node state. + def _init_state(self, num_batch=1): + state = self.init_state(num_batch) + if state is not None: + self.set_state(state) + + def _init_fb_output(self, num_batch=1): + output = self.init_fb_output(num_batch) + if output is not None: + self.set_fb_output(output) + + def init_state(self, num_batch=1) -> Optional[Tensor]: + """Set the initial node state. + This function can be called multiple times.""" pass - def init_fb_output(self, num_batch=1): - """Initialize the node state for feedback. + def init_fb_output(self, num_batch=1) -> Optional[Tensor]: + """Set the initial node feedback state. This function can be called multiple times. However, it is only triggered when the node has feedback connections. - - Parameters - ---------- - num_batch: int - The batch size. """ - state = bm.zeros((num_batch,) + self.output_shape[1:], dtype=bm.float_) - self.set_fb_output(state) + return bm.zeros((num_batch,) + self.output_shape[1:], dtype=bm.float_) def initialize(self, num_batch: int): """ @@ -470,17 +475,17 @@ def initialize(self, num_batch: int): '1. Connecting an instance of "brainpy.nn.Input()" to this node. \n' '2. 
Providing the "input_shape" when initialize the node.') check_integer(num_batch, 'num_batch', min_bound=0, allow_none=False) - self._init_ff() + self._init_ff_conn() # initialize state - self.init_state(num_batch) + self._init_state(num_batch) self._is_state_initialized = True if self.feedback_shapes is not None: # feedback initialization - self._init_fb() + self._init_fb_conn() # initialize feedback state - self.init_fb_output(num_batch) + self._init_fb_output(num_batch) self._is_fb_state_initialized = True def _check_inputs(self, ff, fb=None): @@ -865,18 +870,25 @@ def set_output_shape(self, shape: Dict[str, Sequence[int]]): for val in shape.values(): check_batch_shape(val, self.output_shape) - def init_ff(self): + def init_ff_conn(self): """Initialize the feedforward connections of the network. This function will be called only once.""" # input shapes of entry nodes for node in self.entry_nodes: + # set ff shapes if node.feedforward_shapes is None: if self.feedforward_shapes is None: raise ValueError('Cannot find the input size. ' 'Cannot initialize the network.') else: node.set_feedforward_shapes({node.name: self._feedforward_shapes[node.name]}) - node._init_ff() + # set fb shapes + if node in self.fb_senders: + fb_shapes = {node: node.output_shape for node in self.fb_senders.get(node, [])} + if None not in fb_shapes.values(): + node.set_feedback_shapes(fb_shapes) + # init ff conn + node._init_ff_conn() # initialize the data children_queue = [] @@ -890,10 +902,16 @@ def init_ff(self): children_queue.append(child) while len(children_queue): node = children_queue.pop(0) - # initialize input and output sizes + # set ff shapes parent_sizes = {p: p.output_shape for p in self.ff_senders.get(node, [])} node.set_feedforward_shapes(parent_sizes) - node._init_ff() + if node in self.fb_senders: + # set fb shapes + fb_shapes = {node: node.output_shape for node in self.fb_senders.get(node, [])} + if None not in fb_shapes.values(): + node.set_feedback_shapes(fb_shapes) + # init ff conn + node._init_ff_conn() # append children for child in self.ff_receivers.get(node, []): ff_senders[child].remove(node) @@ -904,28 +922,37 @@ def init_ff(self): out_sizes = {node: node.output_shape for node in self.exit_nodes} self.set_output_shape(out_sizes) - def init_fb(self): + def init_fb_conn(self): """Initialize the feedback connections of the network. This function will be called only once.""" for receiver, senders in self.fb_senders.items(): fb_sizes = {node: node.output_shape for node in senders} + if None in fb_sizes.values(): + none_size_nodes = [repr(n) for n, v in fb_sizes.items() if v is None] + none_size_nodes = "\n".join(none_size_nodes) + raise ValueError(f'Output shapes of nodes \n\n' + f'{none_size_nodes}\n\n' + f'have not been initialized, ' + f'leading us cannot initialize the ' + f'feedback connection of node \n\n' + f'{receiver}') receiver.set_feedback_shapes(fb_sizes) - receiver._init_fb() + receiver._init_fb_conn() - def init_state(self, num_batch=1): + def _init_state(self, num_batch=1): """Initialize the states of all children nodes. This function can be called multiple times.""" for node in self.lnodes: - node.init_state(num_batch) + node._init_state(num_batch) - def init_fb_output(self, num_batch=1): + def _init_fb_output(self, num_batch=1): """Initialize the node feedback state. This function can be called multiple times. However, it is only triggered when the node has feedback connections. 
""" for node in self.feedback_nodes: - node.init_fb_output(num_batch) + node._init_fb_output(num_batch) def initialize(self, num_batch: int): """ @@ -974,11 +1001,11 @@ def initialize(self, num_batch: int): if self.feedforward_shapes is None: raise ValueError('Cannot initialize this node, because we detect ' 'both "feedforward_shapes" is None. ') - check_integer(num_batch, 'num_batch', min_bound=0, allow_none=False) - self._init_ff() + check_integer(num_batch, 'num_batch', min_bound=1, allow_none=False) + self._init_ff_conn() # initialize state - self.init_state(num_batch) + self._init_state(num_batch) self._is_state_initialized = True # set feedback shapes @@ -991,10 +1018,10 @@ def initialize(self, num_batch: int): # feedback initialization if self.feedback_shapes is not None: - self._init_fb() + self._init_fb_conn() # initialize feedback state - self.init_fb_output(num_batch) + self._init_fb_output(num_batch) self._is_fb_state_initialized = True def _check_inputs(self, ff, fb=None): diff --git a/brainpy/nn/nodes/ANN/conv.py b/brainpy/nn/nodes/ANN/conv.py index d74f02b43..84a77c9bf 100644 --- a/brainpy/nn/nodes/ANN/conv.py +++ b/brainpy/nn/nodes/ANN/conv.py @@ -81,7 +81,7 @@ def __init__(self, num_input, num_output, kernel_size, strides=1, dilations=1, self.padding = padding self.groups = groups - def init_ff(self): + def init_ff_conn(self): assert self.num_input % self.groups == 0, '"nin" should be divisible by groups' size = _check_tuple(self.kernel_size) + (self.num_input // self.groups, self.num_output) self.w = init_param(self.w_init, size) diff --git a/brainpy/nn/nodes/ANN/dropout.py b/brainpy/nn/nodes/ANN/dropout.py index 567d6e13b..bbf4e24c5 100644 --- a/brainpy/nn/nodes/ANN/dropout.py +++ b/brainpy/nn/nodes/ANN/dropout.py @@ -44,7 +44,7 @@ def __init__(self, prob, seed=None, **kwargs): self.prob = prob self.rng = bm.random.RandomState(seed=seed) - def init_ff(self): + def init_ff_conn(self): self.set_output_shape(self.feedforward_shapes) def forward(self, ff, **shared_kwargs): diff --git a/brainpy/nn/nodes/ANN/rnn_cells.py b/brainpy/nn/nodes/ANN/rnn_cells.py index ff51de426..ef1d20acd 100644 --- a/brainpy/nn/nodes/ANN/rnn_cells.py +++ b/brainpy/nn/nodes/ANN/rnn_cells.py @@ -45,7 +45,9 @@ def __init__( self.num_unit = num_unit check_integer(num_unit, 'num_unit', min_bound=1, allow_none=False) + self.set_output_shape((None, self.num_unit)) + # initializers self._state_initializer = state_initializer self._wi_initializer = wi_initializer self._wh_initializer = wh_initializer @@ -55,23 +57,23 @@ def __init__( check_initializer(state_initializer, 'state_initializer', allow_none=False) check_initializer(bias_initializer, 'bias_initializer', allow_none=True) + # activation function self.activation = bm.activations.get(activation) - def init_ff(self): + def init_ff_conn(self): unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True) assert len(unique_size) == 1, 'Only support data with or without batch size.' 
- num_input = sum(free_sizes) - self.set_output_shape(unique_size + (self.num_unit,)) # weights + num_input = sum(free_sizes) self.Wff = init_param(self._wi_initializer, (num_input, self.num_unit)) self.Wrec = init_param(self._wh_initializer, (self.num_unit, self.num_unit)) - self.bff = init_param(self._bias_initializer, (self.num_unit,)) + self.bias = init_param(self._bias_initializer, (self.num_unit,)) if self.trainable: self.Wff = bm.TrainVar(self.Wff) self.Wrec = bm.TrainVar(self.Wrec) - self.bff = None if (self.bff is None) else bm.TrainVar(self.bff) + self.bias = None if (self.bias is None) else bm.TrainVar(self.bias) - def init_fb(self): + def init_fb_conn(self): unique_size, free_sizes = check_shape_consistency(self.feedback_shapes, -1, True) assert len(unique_size) == 1, 'Only support data with or without batch size.' num_feedback = sum(free_sizes) @@ -81,29 +83,23 @@ def init_fb(self): self.Wfb = bm.TrainVar(self.Wfb) def init_state(self, num_batch=1): - state = init_param(self._state_initializer, (num_batch, self.num_unit)) - self.set_state(state) + return init_param(self._state_initializer, (num_batch, self.num_unit)) def forward(self, ff, fb=None, **shared_kwargs): ff = bm.concatenate(ff, axis=-1) h = ff @ self.Wff h += self.state.value @ self.Wrec - if self.bff is not None: - h += self.bff + if self.bias is not None: + h += self.bias if fb is not None: fb = bm.concatenate(fb, axis=-1) h += fb @ self.Wfb self.state.value = self.activation(h) return self.state.value - def init_fb_output(self, num_batch=1): - state = init_param(self._state_initializer, (num_batch, self.num_unit)) - self.set_fb_output(state) - class GRU(RecurrentNode): - r""" - Gated Recurrent Unit. + r"""Gated Recurrent Unit. The implementation is based on (Chung, et al., 2014) [1]_ with biases. @@ -145,6 +141,7 @@ def __init__( self.num_unit = num_unit check_integer(num_unit, 'num_unit', min_bound=1, allow_none=False) + self.set_output_shape((None, self.num_unit)) self._wi_initializer = wi_initializer self._wh_initializer = wh_initializer @@ -155,13 +152,13 @@ def __init__( check_initializer(state_initializer, 'state_initializer', allow_none=False) check_initializer(bias_initializer, 'bias_initializer', allow_none=True) - def init_ff(self): + def init_ff_conn(self): # data shape unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True) assert len(unique_size) == 1, 'Only support data with or without batch size.' - num_input = sum(free_sizes) - self.set_output_shape(unique_size + (self.num_unit,)) + # weights + num_input = sum(free_sizes) self.Wi_ff = init_param(self._wi_initializer, (num_input, self.num_unit * 3)) self.Wh = init_param(self._wh_initializer, (self.num_unit, self.num_unit * 3)) self.bias = init_param(self._bias_initializer, (self.num_unit * 3,)) @@ -170,7 +167,7 @@ def init_ff(self): self.Wh = bm.TrainVar(self.Wh) self.bias = bm.TrainVar(self.bias) if (self.bias is not None) else None - def init_fb(self): + def init_fb_conn(self): unique_size, free_sizes = check_shape_consistency(self.feedback_shapes, -1, True) assert len(unique_size) == 1, 'Only support data with or without batch size.' 
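The renamed hooks above also tighten the node life-cycle: `init_ff_conn()` / `init_fb_conn()` build connections exactly once, while `init_state()` is now expected to *return* the initial state and the private `_init_state()` wrapper stores whatever is returned via `set_state()`. Below is a minimal sketch of a custom node written against this contract; the `LeakyUnit` class and its `alpha` parameter are hypothetical and not part of this patch, only the hook names and helpers (`set_output_shape`, `feedforward_shapes`, `state`) come from the code above.

```python
from brainpy import math as bm
from brainpy.nn.base import Node


class LeakyUnit(Node):
  """Hypothetical example node; not part of this patch."""

  def __init__(self, alpha=0.9, **kwargs):
    super(LeakyUnit, self).__init__(**kwargs)
    self.alpha = alpha

  def init_ff_conn(self):
    # called once: keep the feedforward shape as the output shape
    self.set_output_shape(self.feedforward_shapes)

  def init_state(self, num_batch=1):
    # may be called many times: return the state, `_init_state()` stores it
    return bm.zeros((num_batch,) + self.output_shape[1:], dtype=bm.float_)

  def forward(self, ff, fb=None, **shared_kwargs):
    # leaky running average of the (concatenated) feedforward input
    ff = bm.concatenate(ff, axis=-1)
    self.state.value = self.alpha * self.state + (1. - self.alpha) * ff
    return self.state.value
```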
num_feedback = sum(free_sizes) @@ -180,8 +177,7 @@ def init_fb(self): self.Wi_fb = bm.TrainVar(self.Wi_fb) def init_state(self, num_batch=1): - state = init_param(self._state_initializer, (num_batch, self.num_unit)) - self.set_state(state) + return init_param(self._state_initializer, (num_batch, self.num_unit)) def forward(self, ff, fb=None, **shared_kwargs): gates_x = bm.matmul(bm.concatenate(ff, axis=-1), self.Wi_ff) @@ -205,10 +201,6 @@ def forward(self, ff, fb=None, **shared_kwargs): self.state.value = next_state return next_state - def init_fb_output(self, num_batch=1): - state = init_param(self._state_initializer, (num_batch, self.num_unit)) - self.set_fb_output(state) - class LSTM(RecurrentNode): r"""Long short-term memory (LSTM) RNN core. @@ -264,6 +256,7 @@ def __init__( self.num_unit = num_unit check_integer(num_unit, 'num_unit', min_bound=1, allow_none=False) + self.set_output_shape((None, self.num_unit,)) self._state_initializer = state_initializer self._wi_initializer = wi_initializer @@ -274,13 +267,12 @@ def __init__( check_initializer(bias_initializer, 'bias_initializer', allow_none=True) check_initializer(state_initializer, 'state_initializer', allow_none=False) - def init_ff(self): + def init_ff_conn(self): # data shape unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True) assert len(unique_size) == 1, 'Only support data with or without batch size.' - num_input = sum(free_sizes) - self.set_output_shape(unique_size + (self.num_unit,)) # weights + num_input = sum(free_sizes) self.Wi_ff = init_param(self._wi_initializer, (num_input, self.num_unit * 4)) self.Wh = init_param(self._wh_initializer, (self.num_unit, self.num_unit * 4)) self.bias = init_param(self._bias_initializer, (self.num_unit * 4,)) @@ -289,7 +281,7 @@ def init_ff(self): self.Wh = bm.TrainVar(self.Wh) self.bias = None if (self.bias is None) else bm.TrainVar(self.bias) - def init_fb(self): + def init_fb_conn(self): unique_size, free_sizes = check_shape_consistency(self.feedback_shapes, -1, True) assert len(unique_size) == 1, 'Only support data with or without batch size.' 
num_feedback = sum(free_sizes) @@ -299,8 +291,7 @@ def init_fb(self): self.Wi_fb = bm.TrainVar(self.Wi_fb) def init_state(self, num_batch=1): - hc = init_param(self._state_initializer, (num_batch * 2, self.num_unit)) - self.set_state(hc) + return init_param(self._state_initializer, (num_batch * 2, self.num_unit)) def forward(self, ff, fb=None, **shared_kwargs): h, c = bm.split(self.state, 2) @@ -316,10 +307,6 @@ def forward(self, ff, fb=None, **shared_kwargs): self.state.value = bm.vstack([h, c]) return h - def init_fb_output(self, num_batch=1): - state = init_param(self._state_initializer, (num_batch, self.num_unit)) - self.set_fb_output(state) - @property def h(self): """Hidden state.""" diff --git a/brainpy/nn/nodes/RC/linear_readout.py b/brainpy/nn/nodes/RC/linear_readout.py index 8f877e088..40cae1db1 100644 --- a/brainpy/nn/nodes/RC/linear_readout.py +++ b/brainpy/nn/nodes/RC/linear_readout.py @@ -39,12 +39,12 @@ def __init__( super(LinearReadout, self).__init__(num_unit=num_unit, weight_initializer=weight_initializer, bias_initializer=bias_initializer, **kwargs) def init_state(self, num_batch=1): - state = bm.Variable(bm.zeros((num_batch,) + self.output_shape[1:], dtype=bm.float_)) - self.set_state(state) + return bm.zeros((num_batch,) + self.output_shape[1:], dtype=bm.float_) def forward(self, ff, fb=None, **shared_kwargs): - self.state.value = super(LinearReadout, self).forward(ff, fb=fb, **shared_kwargs) - return self.state + h = super(LinearReadout, self).forward(ff, fb=fb, **shared_kwargs) + self.state.value = h + return h def __force_init__(self, train_pars: Optional[Dict] = None): if train_pars is None: train_pars = dict() @@ -76,4 +76,4 @@ def __force_train__(self, # update the weights e = bm.atleast_2d(self.state - target) # (1, num_output) dw = bm.dot(-c * k, e) # (num_hidden, num_output) - self.weights += dw + self.Wff += dw diff --git a/brainpy/nn/nodes/RC/nvar.py b/brainpy/nn/nodes/RC/nvar.py index f09e65aef..83d9546b0 100644 --- a/brainpy/nn/nodes/RC/nvar.py +++ b/brainpy/nn/nodes/RC/nvar.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from itertools import combinations_with_replacement -from typing import Union +from typing import Union, Sequence import numpy as np @@ -9,7 +9,8 @@ from brainpy.nn.base import RecurrentNode from brainpy.tools.checking import (check_shape_consistency, check_float, - check_integer) + check_integer, + check_sequence) __all__ = [ 'NVAR' @@ -46,7 +47,7 @@ class NVAR(RecurrentNode): ---------- delay: int The number of delay step. - order: int + order: int, sequence of int The nonlinear order. stride: int The stride to sample linear part vector in the delays. 
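One detail worth noting in the LSTM hunk above: the hidden state `h` and the cell state `c` live in a single state variable. `init_state()` returns a `(2 * num_batch, num_unit)` tensor, and `forward()` recovers the two halves with `bm.split(self.state, 2)` and writes them back with `bm.vstack([h, c])`. A NumPy-only sketch of that layout (shapes are illustrative; the gate math is omitted):

```python
import numpy as np

num_batch, num_unit = 3, 5

# init_state(): h and c stacked along the leading axis
state = np.zeros((num_batch * 2, num_unit))

# forward(): split into hidden state and cell state ...
h, c = np.split(state, 2)
assert h.shape == c.shape == (num_batch, num_unit)

# ... and stack them back into one variable after the update
state = np.vstack([h, c])
assert state.shape == (num_batch * 2, num_unit)
```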
@@ -63,51 +64,58 @@ class NVAR(RecurrentNode): def __init__(self, delay: int, - order: int, + order: Union[int, Sequence[int]], stride: int = 1, constant: Union[float, int] = None, **kwargs): super(NVAR, self).__init__(**kwargs) - self.delay = delay + if not isinstance(order, (tuple, list)): + order = [order] self.order = order - self.stride = stride - self.constant = constant + check_sequence(order, 'order', elem_type=int, allow_none=False) + self.delay = delay check_integer(delay, 'delay', allow_none=False) - check_integer(order, 'order', allow_none=False) + self.stride = stride check_integer(stride, 'stride', allow_none=False) + self.constant = constant check_float(constant, 'constant', allow_none=True, allow_int=True) - def init_ff(self): + self.comb_ids = [] + # delay variables + self.num_delay = self.delay * self.stride + self.idx = bm.Variable(bm.array([0], dtype=bm.uint32)) + self.store = None + + def init_ff_conn(self): + """Initialize feedforward connections.""" # input dimension batch_size, free_size = check_shape_consistency(self.feedforward_shapes, -1, True) self.input_dim = sum(free_size) assert batch_size == (None,), f'batch_size must be None, but got {batch_size}' - # linear dimension linear_dim = self.delay * self.input_dim - # for each monomial created in the non linear part, indices + # for each monomial created in the non-linear part, indices # of the n components involved, n being the order of the # monomials. Precompute them to improve efficiency. - idx = np.array(list(combinations_with_replacement(np.arange(linear_dim), self.order))) - self.comb_ids = bm.asarray(idx) - # number of non linear components is (d + n - 1)! / (d - 1)! n! + for order in self.order: + idx = np.array(list(combinations_with_replacement(np.arange(linear_dim), order))) + self.comb_ids.append(bm.asarray(idx)) + # number of non-linear components is (d + n - 1)! / (d - 1)! n! + # i.e. number of all unique monomials of order n made from the # linear components. - nonlinear_dim = len(self.comb_ids) + nonlinear_dim = sum([len(ids) for ids in self.comb_ids]) # output dimension - output_dim = int(linear_dim + nonlinear_dim) + self.output_dim = int(linear_dim + nonlinear_dim) if self.constant is not None: - output_dim += 1 - self.set_output_shape((None, output_dim)) - - # delay variables - self.num_delay = self.delay * self.stride - self.idx = bm.Variable(bm.array([0], dtype=bm.uint32)) - self.store = None + self.output_dim += 1 + self.set_output_shape((None, self.output_dim)) def init_state(self, num_batch=1): - # to store the k*s last inputs, k being the delay and s the strides + """Initialize the node state which depends on batch size.""" + # To store the last inputs. + # Note, the batch axis is not in the first dimension, so we + # manually handle the state of NVAR, rather than returning it. state = bm.zeros((self.num_delay, num_batch, self.input_dim), dtype=bm.float_) if self.store is None: self.store = bm.Variable(state) @@ -115,7 +123,8 @@ def init_state(self, num_batch=1): self.store.value = state def forward(self, ff, fb=None, **shared_kwargs): - # 1. store the current input + all_parts = [] + # 1. 
Store the current input ff = bm.concatenate(ff, axis=-1) self.store[self.idx[0]] = ff self.idx.value = (self.idx + 1) % self.num_delay @@ -124,12 +133,15 @@ def forward(self, ff, fb=None, **shared_kwargs): select_ids = (self.idx[0] + bm.arange(self.num_delay)[::self.stride]) % self.num_delay linear_parts = bm.moveaxis(self.store[select_ids], 0, 1) # (num_batch, num_time, num_feature) linear_parts = bm.reshape(linear_parts, (linear_parts.shape[0], -1)) + # 3. constant + if self.constant is not None: + constant = bm.broadcast_to(self.constant, linear_parts.shape[:-1] + (1,)) + all_parts.append(constant) + all_parts.append(linear_parts) # 3. Nonlinear part: # select monomial terms and compute them - nonlinear_parts = bm.prod(linear_parts[:, self.comb_ids], axis=2) - if self.constant is None: - return bm.concatenate([linear_parts, nonlinear_parts], axis=-1) - else: - constant = bm.broadcast_to(self.constant, linear_parts.shape[:-1] + (1,)) - return bm.concatenate([constant, linear_parts, nonlinear_parts], axis=-1) + for ids in self.comb_ids: + all_parts.append(bm.prod(linear_parts[:, ids], axis=2)) + # 4. Return all parts + return bm.concatenate(all_parts, axis=-1) diff --git a/brainpy/nn/nodes/RC/reservoir.py b/brainpy/nn/nodes/RC/reservoir.py index 3b73d36fa..c137501ab 100644 --- a/brainpy/nn/nodes/RC/reservoir.py +++ b/brainpy/nn/nodes/RC/reservoir.py @@ -158,7 +158,7 @@ def __init__( self.noise_type = noise_type check_string(noise_type, 'noise_type', ['normal', 'uniform']) - def init_ff(self): + def init_ff_conn(self): """Initialize feedforward connections, weights, and variables.""" unique_shape, free_shapes = check_shape_consistency(self.feedforward_shapes, -1, True) self.set_output_shape(unique_shape + (self.num_unit,)) @@ -197,10 +197,9 @@ def init_ff(self): def init_state(self, num_batch=1): # initialize internal state - state = bm.Variable(bm.zeros((num_batch, self.num_unit), dtype=bm.float_)) - self.set_state(state) + return bm.zeros((num_batch, self.num_unit), dtype=bm.float_) - def init_fb(self): + def init_fb_conn(self): """Initialize feedback connections, weights, and variables.""" if self.feedback_shapes is not None: unique_shape, free_shapes = check_shape_consistency(self.feedback_shapes, -1, True) diff --git a/brainpy/nn/nodes/base/activation.py b/brainpy/nn/nodes/base/activation.py index fc02c4859..7af429bf6 100644 --- a/brainpy/nn/nodes/base/activation.py +++ b/brainpy/nn/nodes/base/activation.py @@ -37,7 +37,7 @@ def __init__(self, self._fun_setting = dict() if (fun_setting is None) else fun_setting assert isinstance(self._fun_setting, dict), '"fun_setting" must be a dict.' 
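To see what the multi-order `NVAR` above computes: the index sets in `comb_ids` enumerate, for every requested order `n`, all monomials of that order over the `linear_dim = delay * input_dim` delayed features, so the nonlinear dimension is the sum of the binomial coefficients C(d + n - 1, n). A NumPy-only sketch with made-up sizes (the numbers are illustrative, not from the patch):

```python
from itertools import combinations_with_replacement
import numpy as np

delay, input_dim = 2, 3
linear_dim = delay * input_dim   # d = 6 delayed linear features
orders = (2, 3)                  # `order` may now be a sequence

comb_ids = []
for order in orders:
  ids = np.array(list(combinations_with_replacement(np.arange(linear_dim), order)))
  comb_ids.append(ids)

# number of monomials of order n over d features is C(d + n - 1, n)
nonlinear_dim = sum(len(ids) for ids in comb_ids)
print(nonlinear_dim)  # 21 + 56 = 77 for d = 6 and orders (2, 3)

# one nonlinear block: products of the selected feature combinations
linear_parts = np.random.rand(4, linear_dim)               # (num_batch, d)
order2_block = np.prod(linear_parts[:, comb_ids[0]], axis=2)
print(order2_block.shape)                                   # (4, 21)
```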
- def init_ff(self): + def init_ff_conn(self): self.set_output_shape(self.feedforward_shapes) def forward(self, ff, **shared_kwargs): diff --git a/brainpy/nn/nodes/base/dense.py b/brainpy/nn/nodes/base/dense.py index cfa48cb1d..0f8ea2067 100644 --- a/brainpy/nn/nodes/base/dense.py +++ b/brainpy/nn/nodes/base/dense.py @@ -8,8 +8,8 @@ from brainpy import math as bm from brainpy.errors import UnsupportedError, MathError from brainpy.initialize import XavierNormal, ZeroInit, Initializer -from brainpy.nn import utils from brainpy.nn.base import Node +from brainpy.nn.utils import init_param from brainpy.tools.checking import (check_shape_consistency, check_initializer) from brainpy.types import Tensor @@ -49,37 +49,60 @@ def __init__( **kwargs ): super(Dense, self).__init__(trainable=trainable, **kwargs) + + # shape self.num_unit = num_unit if num_unit < 0: raise ValueError(f'Received an invalid value for `num_unit`, expected ' f'a positive integer. Received: num_unit={num_unit}') + + # weight initializer self.weight_initializer = weight_initializer self.bias_initializer = bias_initializer check_initializer(weight_initializer, 'weight_initializer') check_initializer(bias_initializer, 'bias_initializer', allow_none=True) - def init_ff(self): + # weights + self.Wff = None + self.bias = None + self.Wfb = None + + def init_ff_conn(self): # shapes - in_sizes = [size[1:] for size in self.feedforward_shapes] # remove batch size - unique_shape, free_shapes = check_shape_consistency(in_sizes, -1, True) - weight_shape = (sum(free_shapes), self.num_unit) - bias_shape = (self.num_unit,) - # set output size - self.set_output_shape((None, ) + unique_shape + (self.num_unit,)) + other_size, free_shapes = check_shape_consistency(self.feedforward_shapes, -1, True) + self._other_size = other_size + # set output size # TODO + self.set_output_shape(other_size + (self.num_unit,)) + # initialize feedforward weights - self.weights = utils.init_param(self.weight_initializer, weight_shape) - self.bias = utils.init_param(self.bias_initializer, bias_shape) + self.Wff = init_param(self.weight_initializer, (sum(free_shapes), self.num_unit)) + self.bias = init_param(self.bias_initializer, (self.num_unit,)) if self.trainable: - self.weights = bm.TrainVar(self.weights) - if self.bias is not None: - self.bias = bm.TrainVar(self.bias) + self.Wff = bm.TrainVar(self.Wff) + self.bias = bm.TrainVar(self.bias) if (self.bias is not None) else None - def forward(self, ff: Sequence[Tensor], **shared_kwargs): - ff = bm.concatenate(ff, axis=-1) - if self.bias is None: - return ff @ self.weights + def init_fb_conn(self): + other_size, free_shapes = check_shape_consistency(self.feedback_shapes, -1, True) + if self._other_size != other_size: + raise ValueError(f'The feedback shape {other_size} is not consistent ' + f'with the feedforward shape {self._other_size}') + + # initialize feedforward weights + weight_shapes = (sum(free_shapes), self.num_unit) + if self.trainable: + self.Wfb = bm.TrainVar(init_param(self.weight_initializer, weight_shapes)) else: - return ff @ self.weights + self.bias + self.Wfb = init_param(self.weight_initializer, weight_shapes) + + def forward(self, ff: Sequence[Tensor], fb=None, **shared_kwargs): + ff = bm.concatenate(ff, axis=-1) + res = ff @ self.Wff + if fb is not None: + fb = bm.concatenate(fb, axis=-1) + res += fb @ self.Wfb + if self.bias is not None: + res += self.bias + return res def __ridge_train__(self, ffs: Sequence[Tensor], @@ -93,6 +116,7 @@ def __ridge_train__(self, Also, the element in ``ffs`` 
should have the same shape. """ + assert self.Wfb is None, 'Currently ridge learning does not support feedback connections.' # parameters if train_pars is None: train_pars = dict() @@ -119,9 +143,9 @@ def __ridge_train__(self, W = bm.linalg.pinv(temp) @ (ffs.T @ targets) # assign trained weights if self.bias is None: - self.weights.value = W + self.Wff.value = W else: - self.weights.value = W[:-1] + self.Wff.value = W[:-1] self.bias.value = W[-1] def __force_init__(self, *args, **kwargs): diff --git a/brainpy/nn/nodes/base/io.py b/brainpy/nn/nodes/base/io.py index d4a098907..2b7e72c2a 100644 --- a/brainpy/nn/nodes/base/io.py +++ b/brainpy/nn/nodes/base/io.py @@ -21,9 +21,9 @@ def __init__(self, name: str = None): super(Input, self).__init__(name=name, input_shape=input_shape) self.set_feedforward_shapes({self.name: (None,) + to_size(input_shape)}) - self._init_ff() + self._init_ff_conn() - def init_ff(self): + def init_ff_conn(self): self.set_output_shape(self.feedforward_shapes) def forward(self, ff, **shared_kwargs): diff --git a/brainpy/nn/nodes/base/ops.py b/brainpy/nn/nodes/base/ops.py index 958f3158e..328714510 100644 --- a/brainpy/nn/nodes/base/ops.py +++ b/brainpy/nn/nodes/base/ops.py @@ -23,7 +23,7 @@ def __init__(self, axis=-1, **kwargs): super(Concat, self).__init__(**kwargs) self.axis = axis - def init_ff(self): + def init_ff_conn(self): unique_shape, free_shapes = check_shape_consistency(self.feedforward_shapes, self.axis) out_size = list(unique_shape) out_size.insert(self.axis, sum(free_shapes)) @@ -45,7 +45,7 @@ def __init__(self, index, **kwargs): if isinstance(index, int): self.index = bm.asarray([index]).value - def init_ff(self): + def init_ff_conn(self): out_size = bm.zeros(self.feedforward_shapes[1:])[self.index].shape self.set_output_shape((None, ) + out_size) @@ -69,7 +69,7 @@ def __init__(self, shape, **kwargs): self.shape = tools.to_size(shape) assert (None not in self.shape), 'Batch size can not be defined in the reshaped size.' - def init_ff(self): + def init_ff_conn(self): in_size = self.feedforward_shapes[1:] if -1 in self.shape: assert self.shape.count(-1) == 1, f'Cannot set shape with multiple -1. But got {self.shape}' @@ -97,7 +97,7 @@ class Summation(Node): def __init__(self, **kwargs): super(Summation, self).__init__(**kwargs) - def init_ff(self): + def init_ff_conn(self): unique_shape, _ = check_shape_consistency(self.feedforward_shapes, None, True) self.set_output_shape(list(unique_shape)) diff --git a/brainpy/nn/operations.py b/brainpy/nn/operations.py index fbebc44f8..acb18e39b 100644 --- a/brainpy/nn/operations.py +++ b/brainpy/nn/operations.py @@ -365,7 +365,7 @@ def fb_connect( all_nodes, all_ff_edges, all_fb_edges, fb_senders, fb_receivers = _retrieve_nodes_and_edges(senders, receivers) - # detect whether the node implement its own "init_fb()" function + # detect whether the node implements its own "init_fb_conn()" function for node in fb_receivers: if not node.is_feedback_input_supported: raise ValueError(f'Establish a feedback connection to \n' diff --git a/brainpy/tools/checking.py b/brainpy/tools/checking.py index add5080b9..60a255129 100644 --- a/brainpy/tools/checking.py +++ b/brainpy/tools/checking.py @@ -20,6 +20,7 @@ 'check_float', 'check_integer', 'check_string', + 'check_sequence', ] @@ -209,6 +210,25 @@ def check_connector(connector: Union[Callable, conn.Connector, Tensor], f'tensor or callable function. 
While we got {type(connector)}') +def check_sequence(value: Sequence, + name=None, + elem_type=None, + allow_none=True): + if name is None: name = '' + if value is None: + if allow_none: + return + else: + raise ValueError(f'{name} must be a sequence, but got None') + if not isinstance(value, (tuple, list)): + raise ValueError(f'{name} should be a sequence, but we got a {type(value)}') + if elem_type is not None: + for v in value: + if not isinstance(v, elem_type): + raise ValueError(f'Elements in {name} should be {elem_type}, ' + f'but we got {type(v)}: {v}') + + def check_float(value: float, name=None, min_bound=None, max_bound=None, allow_none=False, allow_int=True): """Check float type. diff --git a/brainpy/tools/others/__init__.py b/brainpy/tools/others/__init__.py index 381d79938..30fe2ea5a 100644 --- a/brainpy/tools/others/__init__.py +++ b/brainpy/tools/others/__init__.py @@ -3,3 +3,4 @@ from .ast2code import * from .dicts import * from .others import * +from .numba_jit import * diff --git a/brainpy/tools/others/numba_jit.py b/brainpy/tools/others/numba_jit.py new file mode 100644 index 000000000..062eadfdc --- /dev/null +++ b/brainpy/tools/others/numba_jit.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- + +try: + from numba import njit +except (ImportError, ModuleNotFoundError): + njit = None + + +__all__ = [ + 'numba_jit' +] + + +def numba_jit(f=None, **kwargs): + if f is None: + return lambda f: (f if (njit is None) else njit(f, **kwargs)) + else: + if njit is None: + return f + else: + return njit(f) + diff --git a/examples/simulation/Brette_2007_COBA.py b/examples/simulation/Brette_2007_COBA.py index feda5f622..756ae225f 100644 --- a/examples/simulation/Brette_2007_COBA.py +++ b/examples/simulation/Brette_2007_COBA.py @@ -15,8 +15,8 @@ def __init__(self, scale=1.0, method='exp_auto'): pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) E = bp.dyn.LIF(num_exc, **pars, method=method) I = bp.dyn.LIF(num_inh, **pars, method=method) - E.num[:] = bp.math.random.randn(num_exc) * 2 - 55. - I.num[:] = bp.math.random.randn(num_inh) * 2 - 55. + E.V[:] = bp.math.random.randn(num_exc) * 2 - 55. + I.V[:] = bp.math.random.randn(num_inh) * 2 - 55. 
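The new `numba_jit` helper above degrades gracefully: when Numba is installed the wrapped function is compiled with `njit`, otherwise it runs as plain Python. A small usage sketch; `count_upward_crossings` is a made-up function, not part of the repository:

```python
import numpy as np

from brainpy.tools.others import numba_jit


@numba_jit
def count_upward_crossings(x, threshold):
  # compiled by Numba when it is available, ordinary Python otherwise
  n = 0
  for i in range(1, x.shape[0]):
    if x[i - 1] < threshold and x[i] >= threshold:
      n += 1
  return n


signal = np.sin(np.linspace(0., 20., 1000))
print(count_upward_crossings(signal, 0.5))
```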
# synapses we = 0.6 / scale # excitatory synaptic weight (voltage) diff --git a/examples/training/Gauthier_2021_ngrc_double_scroll.py b/examples/training/Gauthier_2021_ngrc_double_scroll.py index d01b3a239..33863c0c4 100644 --- a/examples/training/Gauthier_2021_ngrc_double_scroll.py +++ b/examples/training/Gauthier_2021_ngrc_double_scroll.py @@ -132,7 +132,7 @@ def plot_double_scroll(ground_truth, predictions): outputs = trainer.predict(X_warmup) print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup)) trainer.fit([X_train, {'readout': dX_train}]) -plot_weights(di.weights.numpy(), di.bias.numpy(), r.comb_ids.numpy()) +plot_weights(di.Wff.numpy(), di.bias.numpy(), r.comb_ids.numpy()) # prediction model = bm.jit(model) diff --git a/examples/training/Gauthier_2021_ngrc_lorenz.py b/examples/training/Gauthier_2021_ngrc_lorenz.py index bb0fa81ce..66e5cc060 100644 --- a/examples/training/Gauthier_2021_ngrc_lorenz.py +++ b/examples/training/Gauthier_2021_ngrc_lorenz.py @@ -135,7 +135,7 @@ def plot_lorenz(ground_truth, predictions): outputs = trainer.predict(X_warmup) print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup)) trainer.fit([X_train, {'readout': dX_train}]) -plot_weights(di.weights.numpy(), di.bias.numpy(), r.comb_ids.numpy()) +plot_weights(di.Wff.numpy(), di.bias.numpy(), r.comb_ids.numpy()) # prediction model = bm.jit(model) diff --git a/examples/training/integrator_rnn.py b/examples/training/integrator_rnn.py index 13780f07a..5af37c2b2 100644 --- a/examples/training/integrator_rnn.py +++ b/examples/training/integrator_rnn.py @@ -65,7 +65,7 @@ def loss(predictions, targets, l2_reg=2e-4): plt.plot(trainer.train_losses.numpy()) plt.show() -model.init_state(1) +model.initialize(1) x, y = build_inputs_and_targets(batch_size=1) predicts = trainer.predict(x) diff --git a/requirements-dev.txt b/requirements-dev.txt index 670959ba3..bd3ca8783 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,6 @@ -r requirements.txt +numba matplotlib>=3.4 jaxlib>=0.1.64 sympy>=1.6 diff --git a/requirements-doc.txt b/requirements-doc.txt index 450150053..9bf407507 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -5,6 +5,7 @@ jaxlib>=0.1.64 sympy>=1.6 scipy>=1.1.0 brainpylib +numba # document requirements pandoc diff --git a/requirements-win.txt b/requirements-win.txt index 6add3372f..7efaf71b2 100644 --- a/requirements-win.txt +++ b/requirements-win.txt @@ -1,5 +1,6 @@ numpy>=1.15 tqdm +numba matplotlib>=3.4 sympy>=1.6 scipy>=1.1.0
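Finally, the `check_sequence` validator added to `brainpy/tools/checking.py` is what lets `NVAR` accept `order` as a list of integers. Roughly, it behaves as in this short sketch:

```python
from brainpy.tools.checking import check_sequence

check_sequence([1, 2, 3], 'order', elem_type=int, allow_none=False)  # passes
check_sequence(None, 'order', allow_none=True)                       # passes silently

try:
  check_sequence((1, 2.5), 'order', elem_type=int)
except ValueError as err:
  print(err)  # elements must be integers
```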