Support initializing a Variable by data shape (#265)
chaoming0625 authored Oct 2, 2022
2 parents b8691ae + 359dcbe commit b4ef718
Showing 12 changed files with 271 additions and 78 deletions.
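In short, the commit lets a Variable (and its subclasses TrainVar and Parameter) be created from a data shape instead of concrete data. A minimal sketch of the new behaviour, based on the diff to brainpy/math/jaxarray.py below:

import brainpy.math as bm

# Before: a Variable had to wrap concrete data.
v1 = bm.Variable(bm.zeros(10))

# After: an int, or a tuple/list of ints, is interpreted as a shape and
# the Variable is filled with zeros of that shape.
v2 = bm.Variable(10)        # shape (10,), all zeros
v3 = bm.Variable((2, 3))    # shape (2, 3), all zeros

assert bm.array_equal(v1, v2)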
4 changes: 2 additions & 2 deletions brainpy/dyn/base.py
@@ -756,8 +756,8 @@ def __init__(

def __repr__(self):
names = self.__class__.__name__
return (f'{names}(name={self.name}, mode={self.mode}, '
f'{" " * len(names)} pre={self.pre}, '
return (f'{names}(name={self.name}, mode={self.mode}, \n'
f'{" " * len(names)} pre={self.pre}, \n'
f'{" " * len(names)} post={self.post})')

def check_pre_attrs(self, *attrs):
4 changes: 0 additions & 4 deletions brainpy/dyn/layers/dropout.py
@@ -15,10 +15,6 @@ class Dropout(DynamicalSystem):
In training, to compensate for the fraction of input values dropped (`rate`),
all surviving values are multiplied by `1 / (1 - rate)`.
The parameter `shared_axes` allows to specify a list of axes on which
the mask will be shared: we will use size 1 on those axes for dropout mask
and broadcast it. Sharing reduces randomness, but can save memory.
This layer is active only during training (`mode='train'`). In other
circumstances it is a no-op.
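For reference, the inverted-dropout scaling described in the docstring above can be sketched as follows. This is only an illustration with brainpy.math primitives (assuming bm.random.bernoulli and bm.where behave like their NumPy counterparts), not the layer's actual update code:

import brainpy.math as bm

def dropout_sketch(x, rate, fit=True):
  # Drop a fraction `rate` of the entries and rescale the survivors by
  # 1 / (1 - rate), so the expected value of the output equals the input.
  # Outside of training the layer is a no-op.
  if not fit:
    return x
  keep_mask = bm.random.bernoulli(p=1 - rate, size=x.shape)
  return bm.where(keep_mask, x / (1 - rate), 0.)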
4 changes: 2 additions & 2 deletions brainpy/dyn/synapses/gap_junction.py
@@ -29,8 +29,8 @@ def __init__(
conn=conn,
name=name)
# checking
self.check_pre_attrs('V', 'spike')
self.check_post_attrs('V', 'input', 'spike')
self.check_pre_attrs('V')
self.check_post_attrs('V', 'input')

# assert isinstance(self.output, _NullSynOut)
# assert isinstance(self.stp, _NullSynSTP)
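The relaxed checks reflect that a gap junction only needs the membrane potentials of the connected groups (plus an 'input' slot on the post side); spikes play no role. A sketch of the standard gap-junction coupling current, for context only:

def gap_junction_current(V_pre, V_post, g_max):
  # Ohmic electrical coupling: the current injected into the post-synaptic
  # cell depends only on the potential difference across the junction.
  return g_max * (V_pre - V_post)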
42 changes: 37 additions & 5 deletions brainpy/math/jaxarray.py
@@ -885,10 +885,42 @@ def __jax_array__(self):

class Variable(JaxArray):
"""The pointer to specify the dynamical variable.
Initializing an instance of ``Variable`` by two ways:
>>> import brainpy.math as bm
>>> # 1. init a Variable by the concreate data
>>> v1 = bm.Variable(bm.zeros(10))
>>> # 2. init a Variable by the data shape
>>> v2 = bm.Variable(10)
Note that when initializing a `Variable` by the data shape,
all values in this `Variable` will be initialized as zeros.
Parameters
----------
value_or_size: Shape, Array
The value or the size of the value.
dtype:
The type of the data.
batch_axis: optional, int
The batch axis.
"""
__slots__ = ('_value', '_batch_axis')

def __init__(self, value, dtype=None, batch_axis: int = None):
def __init__(
self,
value_or_size,
dtype=None,
batch_axis: int = None
):
if isinstance(value_or_size, int):
value = jnp.zeros(value_or_size, dtype=dtype)
elif isinstance(value_or_size, (tuple, list)) and all([isinstance(s, int) for s in value_or_size]):
value = jnp.zeros(value_or_size, dtype=dtype)
else:
value = value_or_size

super(Variable, self).__init__(value, dtype=dtype)

# check batch axis
@@ -1464,17 +1496,17 @@ class TrainVar(Variable):
"""
__slots__ = ('_value', '_batch_axis')

def __init__(self, value, dtype=None, batch_axis: int = None):
super(TrainVar, self).__init__(value, dtype=dtype, batch_axis=batch_axis)
def __init__(self, value_or_size, dtype=None, batch_axis: int = None):
super(TrainVar, self).__init__(value_or_size, dtype=dtype, batch_axis=batch_axis)


class Parameter(Variable):
"""The pointer to specify the parameter.
"""
__slots__ = ('_value', '_batch_axis')

def __init__(self, value, dtype=None, batch_axis: int = None):
super(Parameter, self).__init__(value, dtype=dtype, batch_axis=batch_axis)
def __init__(self, value_or_size, dtype=None, batch_axis: int = None):
super(Parameter, self).__init__(value_or_size, dtype=dtype, batch_axis=batch_axis)


register_pytree_node(JaxArray,
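Since TrainVar and Parameter simply forward value_or_size to Variable, the shape-based form is available for them as well. A brief sketch (the shapes are arbitrary examples):

import brainpy.math as bm

w = bm.TrainVar((128, 10))   # trainable weights, zero-initialized
p = bm.Parameter(10)         # parameter vector, zero-initialized

# batch_axis is unchanged: here the leading axis of a zero-initialized
# variable is marked as the batch dimension.
h = bm.Variable((1, 128), batch_axis=0)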
11 changes: 11 additions & 0 deletions brainpy/math/tests/test_jaxarray.py
@@ -40,3 +40,14 @@ def test_none(self):
ee = a + e


class TestVariable(unittest.TestCase):
def test_variable_init(self):
self.assertTrue(
bm.array_equal(bm.Variable(bm.zeros(10)),
bm.Variable(10))
)
bm.random.seed(123)
self.assertTrue(
not bm.array_equal(bm.Variable(bm.random.rand(10)),
bm.Variable(10))
)
101 changes: 101 additions & 0 deletions examples/simulation/Fazli_2022_gj_coupled_bursting_pituitary_cells.py
@@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-


""""
Implementation of the paper:
- Fazli, Mehran, and Richard Bertram. "Network Properties of Electrically
Coupled Bursting Pituitary Cells." Frontiers in Endocrinology 13 (2022).
"""

import brainpy as bp
import brainpy.math as bm


class PituitaryCell(bp.NeuGroup):
def __init__(self, size, name=None):
super(PituitaryCell, self).__init__(size, name=name)

# parameter values
self.vn = -5
self.kc = 0.12
self.ff = 0.005
self.vca = 60
self.vk = -75
self.vl = -50.0
self.gk = 2.5
self.cm = 5
self.gbk = 1
self.gca = 2.1
self.gsk = 2
self.vm = -20
self.vb = -5
self.sn = 10
self.sm = 12
self.sbk = 2
self.taun = 30
self.taubk = 5
self.ks = 0.4
self.alpha = 0.0015
self.gl = 0.2

# variables
self.V = bm.Variable(bm.random.random(self.num) * -90 + 20)
self.n = bm.Variable(bm.random.random(self.num) / 2)
self.b = bm.Variable(bm.random.random(self.num) / 2)
self.c = bm.Variable(bm.random.random(self.num))
self.input = bm.Variable(self.num)

# integrators
self.integral = bp.odeint(bp.JointEq(self.dV, self.dn, self.dc, self.db), method='exp_euler')

def dn(self, n, t, V):
  # gating variable of the delayed-rectifier K+ current
  ninf = 1 / (1 + bm.exp((self.vn - V) / self.sn))
  return (ninf - n) / self.taun

def db(self, b, t, V):
  # gating variable of the BK-type K+ current
  bkinf = 1 / (1 + bm.exp((self.vb - V) / self.sbk))
  return (bkinf - b) / self.taubk

def dc(self, c, t, V):
  # cytosolic Ca2+ concentration, driven by the Ca2+ current and removal
  minf = 1 / (1 + bm.exp((self.vm - V) / self.sm))
  ica = self.gca * minf * (V - self.vca)
  return -self.ff * (self.alpha * ica + self.kc * c)

def dV(self, V, t, n, b, c):
  # membrane potential: Ca2+, SK, BK, delayed-rectifier K+ and leak currents
  minf = 1 / (1 + bm.exp((self.vm - V) / self.sm))
  cinf = c ** 2 / (c ** 2 + self.ks * self.ks)
  ica = self.gca * minf * (V - self.vca)
  isk = self.gsk * cinf * (V - self.vk)
  ibk = self.gbk * b * (V - self.vk)
  ikdr = self.gk * n * (V - self.vk)
  il = self.gl * (V - self.vl)
  return -(ica + isk + ibk + ikdr + il + self.input) / self.cm

def update(self, tdi, x=None):
V, n, c, b = self.integral(self.V.value, self.n.value, self.c.value, self.b.value, tdi.t, tdi.dt)
self.V.value = V
self.n.value = n
self.c.value = c
self.b.value = b

def clear_input(self):
self.input.value = bm.zeros_like(self.input)


class PituitaryNetwork(bp.Network):
def __init__(self, num, gc):
super(PituitaryNetwork, self).__init__()

self.N = PituitaryCell(num)
self.gj = bp.synapses.GapJunction(self.N, self.N, bp.conn.All2All(include_self=False), g_max=gc)


if __name__ == '__main__':
net = PituitaryNetwork(2, 0.002)
runner = bp.DSRunner(net, monitors={'V': net.N.V}, dt=0.5)
runner.run(10 * 1e3)

fig, gs = bp.visualize.get_figure(1, 1, 6, 10)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, plot_ids=(0, 1), show=True)
3 changes: 2 additions & 1 deletion extensions/brainpylib/atomic_sum.py
@@ -115,7 +115,8 @@ def _atomic_sum_translation(c, values, pre_ids, post_ids, *, post_num, platform=
shape_with_layout=x_shape(np.dtype(values_dtype), (post_num,), (0,)),
)
elif platform == 'gpu':
if gpu_ops is None: raise ValueError('Cannot find compiled gpu wheels.')
if gpu_ops is None:
raise ValueError('Cannot find compiled gpu wheels.')

opaque = gpu_ops.build_atomic_sum_descriptor(conn_size, post_num)
if values_dim[0] != 1:
60 changes: 30 additions & 30 deletions extensions/brainpylib/event_sum.py
@@ -7,14 +7,17 @@

from functools import partial

from typing import Union, Tuple
import jax.numpy as jnp
import numpy as np
from jax import core
from jax import core, dtypes
from jax.abstract_arrays import ShapedArray
from jax.interpreters import xla, batching
from jax.lax import scan
from jax.lib import xla_client

from .utils import GPUOperatorNotFound

try:
from . import gpu_ops
except ImportError:
@@ -26,7 +29,10 @@
_event_sum_prim = core.Primitive("event_sum")


def event_sum(events, pre2post, post_num, values):
def event_sum(events: jnp.ndarray,
pre2post: Tuple[jnp.ndarray, jnp.ndarray],
post_num: int,
values: Union[float, jnp.ndarray]):
# events
if events.dtype != jnp.bool_:
raise ValueError(f'"events" must be a vector of bool, while we got {events.dtype}')
@@ -39,17 +45,16 @@ def event_sum(events, pre2post, post_num, values):
if indices.dtype != indptr.dtype:
raise ValueError(f"The dtype of pre2post[0] must be equal to that of pre2post[1], "
f"while we got {(indices.dtype, indptr.dtype)}")
if indices.dtype not in [jnp.uint32, jnp.uint64]:
raise ValueError(f'The dtype of pre2post must be uint32 or uint64, while we got {indices.dtype}')
if indices.dtype not in [jnp.uint32, jnp.uint64, jnp.int32, jnp.int64]:
raise ValueError(f'The dtype of pre2post must be integer, while we got {indices.dtype}')

# output value
values = jnp.asarray([values])
if values.dtype not in [jnp.float32, jnp.float64]:
raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {values.dtype}.')
if values.size not in [1, indices.size]:
dtype = values.dtype if isinstance(values, jnp.ndarray) else dtypes.canonicalize_dtype(type(values))
if dtype not in [jnp.float32, jnp.float64]:
raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {dtype}.')
if np.size(values) not in [1, indices.size]:
raise ValueError(f'The size of "values" must be 1 (a scalar) or len(pre2post[0]) (a vector), '
f'while we got {values.size} != 1 != {indices.size}')
values = values.flatten()
f'while we got {np.size(values)} != 1 != {indices.size}')
# bind operator
return _event_sum_prim.bind(events, indices, indptr, values, post_num=post_num)

@@ -58,34 +63,27 @@ def _event_sum_abstract(events, indices, indptr, values, *, post_num):
return ShapedArray(dtype=values.dtype, shape=(post_num,))


_event_sum_prim.def_abstract_eval(_event_sum_abstract)
_event_sum_prim.def_impl(partial(xla.apply_primitive, _event_sum_prim))


def _event_sum_translation(c, events, indices, indptr, values, *, post_num, platform="cpu"):
# The pre/post shape
# The shape of pre/post
pre_size = np.array(c.get_shape(events).dimensions()[0], dtype=np.uint32)
_pre_shape = x_shape(np.dtype(np.uint32), (), ())
_post_shape = x_shape(np.dtype(np.uint32), (), ())

# The indices shape
indices_shape = c.get_shape(indices)
Itype = indices_shape.element_type()
assert Itype in [np.uint32, np.uint64]

# The value shape
values_shape = c.get_shape(values)
Ftype = values_shape.element_type()
assert Ftype in [np.float32, np.float64]
values_dim = values_shape.dimensions()

# We dispatch a different call depending on the dtype
f_type = b'_f32' if Ftype == np.float32 else b'_f64'
i_type = b'_i32' if Itype == np.uint32 else b'_i64'
f_type = b'_f32' if Ftype in [np.float32] else b'_f64'
i_type = b'_i32' if Itype in [np.uint32, np.int32] else b'_i64'

# And then the following is what changes between the GPU and CPU
if platform == "cpu":
v_type = b'_event_sum_homo' if values_dim[0] == 1 else b'_event_sum_heter'
v_type = b'_event_sum_homo' if len(values_dim) == 0 else b'_event_sum_heter'
return x_ops.CustomCallWithLayout(
c,
platform.encode() + v_type + f_type + i_type,
@@ -103,9 +101,12 @@ def _event_sum_translation(c, events, indices, indptr, values, *, post_num, plat
c.get_shape(values)),
shape_with_layout=x_shape(np.dtype(Ftype), (post_num,), (0,)),
)

# GPU platform
elif platform == 'gpu':
if gpu_ops is None:
raise ValueError('Cannot find compiled gpu wheels.')
raise GPUOperatorNotFound('event_sum')

v_type = b'_event_sum_homo' if values_dim[0] == 1 else b'_event_sum_heter'
opaque = gpu_ops.build_event_sum_descriptor(pre_size, post_num)
return x_ops.CustomCallWithLayout(
@@ -127,11 +128,7 @@ def _event_sum_translation(c, events, indices, indptr, values, *, post_num, plat
raise ValueError("Unsupported platform, we only support 'cpu' or 'gpu'")


xla.backend_specific_translations["cpu"][_event_sum_prim] = partial(_event_sum_translation, platform="cpu")
xla.backend_specific_translations["gpu"][_event_sum_prim] = partial(_event_sum_translation, platform="gpu")


def _event_sum_batch(args, axes):
def _event_sum_batch(args, axes, *, post_num):
batch_axes, batch_args, non_batch_args = [], {}, {}
for ax_i, ax in enumerate(axes):
if ax is None:
@@ -143,19 +140,22 @@ def _event_sum_batch(args, axes):
def f(_, x):
pars = tuple([(x[f'ax{i}'] if i in batch_axes else non_batch_args[f'ax{i}'])
for i in range(len(axes))])
return 0, _event_sum_prim.bind(*pars)
return 0, _event_sum_prim.bind(*pars, post_num=post_num)

_, outs = scan(f, 0, batch_args)
return outs, 0


_event_sum_prim.def_abstract_eval(_event_sum_abstract)
_event_sum_prim.def_impl(partial(xla.apply_primitive, _event_sum_prim))
batching.primitive_batchers[_event_sum_prim] = _event_sum_batch

xla.backend_specific_translations["cpu"][_event_sum_prim] = partial(_event_sum_translation, platform="cpu")
xla.backend_specific_translations["gpu"][_event_sum_prim] = partial(_event_sum_translation, platform="gpu")

# ---------------------------
# event sum kernel 2
# ---------------------------


_event_sum2_prim = core.Primitive("event_sum2")


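For readers unfamiliar with the operator, the semantics of event_sum can be sketched in plain NumPy. This is a reference illustration of what the custom kernel computes (per-connection weights indexed in CSR order), not the kernel itself:

import numpy as np

def event_sum_reference(events, pre2post, post_num, values):
  # events: bool vector over pre-synaptic neurons.
  # pre2post: (indices, indptr) in CSR form; indices[indptr[i]:indptr[i+1]]
  #           are the post ids reached by pre neuron i.
  # values: a scalar (homogeneous weight) or one weight per connection.
  indices, indptr = pre2post
  out = np.zeros(post_num, dtype=np.asarray(values).dtype)
  homo = np.size(values) == 1
  for i in np.where(events)[0]:
    start, end = indptr[i], indptr[i + 1]
    vals = values if homo else values[start:end]
    np.add.at(out, indices[start:end], vals)
  return out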
18 changes: 18 additions & 0 deletions extensions/brainpylib/utils.py
@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-


__all__ = [
'GPUOperatorNotFound',
]


class GPUOperatorNotFound(Exception):
def __init__(self, name):
super(GPUOperatorNotFound, self).__init__(f'''
GPU operator for "{name}" is not found.
Please compile the brainpylib GPU operators following the guidance at the link below:
https://brainpy.readthedocs.io/en/latest/tutorial_advanced/compile_brainpylib.html
''')
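
A minimal sketch of how the new exception is meant to be used, mirroring the guard added to event_sum above (the absolute import path here is only for illustration; inside the package the relative import shown in the diff is used):

from brainpylib.utils import GPUOperatorNotFound

try:
  from brainpylib import gpu_ops   # compiled GPU kernels, if available
except ImportError:
  gpu_ops = None

def require_gpu_ops(op_name):
  # Raise the descriptive error when the compiled GPU wheels are missing.
  if gpu_ops is None:
    raise GPUOperatorNotFound(op_name)
  return gpu_ops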
