Merge pull request #355 from chaoming0625/master
Enable memory-efficient ``DSRunner``
chaoming0625 authored Apr 7, 2023
2 parents 99f22b6 + 6f73232 commit 00c790a
Showing 7 changed files with 252 additions and 190 deletions.
brainpy/__init__.py (5 changes: 3 additions & 2 deletions)
@@ -81,7 +81,9 @@
from brainpy._src.dyn.transform import (LoopOverTime as LoopOverTime,)
# DynamicalSystem base classes
from brainpy._src.dyn.base import (DynamicalSystemNS as DynamicalSystemNS,
NeuGroupNS as NeuGroupNS)
NeuGroupNS as NeuGroupNS,
TwoEndConnNS as TwoEndConnNS,
)
from brainpy._src.dyn.synapses_v2.base import (SynOutNS as SynOutNS,
SynSTPNS as SynSTPNS,
SynConnNS as SynConnNS, )
@@ -111,7 +113,6 @@

from . import running, testing
from ._src.visualization import (visualize as visualize)
from ._src.running.runner import (Runner as Runner)


# Part 7: Deprecations #
brainpy/_src/dyn/base.py (7 changes: 6 additions & 1 deletion)
@@ -767,7 +767,7 @@ def check_post_attrs(self, *attrs):
if not hasattr(self.post, attr):
raise ValueError(f'{self} need "pre" neuron group has attribute "{attr}".')

def update(self, tdi, pre_spike=None):
def update(self, *args, **kwargs):
"""The function to specify the updating rule.
Assume any dynamical system depends on the shared variables (`sha`),
@@ -1024,6 +1024,11 @@ def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
return post_vs


class TwoEndConnNS(TwoEndConn):
"""Two-end connection without passing shared arguments."""
_pass_shared_args = False


class CondNeuGroup(NeuGroup, Container):
r"""Base class to model conductance-based neuron group.
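
The change above generalizes `TwoEndConn.update()` to `*args, **kwargs` and adds `TwoEndConnNS`, whose `_pass_shared_args = False` means subclasses are updated without the shared-argument dict (`tdi`). A minimal sketch of such a subclass, purely illustrative (the class name, parameters, and dynamics below are made up and are not part of this commit):

    import brainpy as bp
    import brainpy.math as bm

    class ExpDecaySyn(bp.TwoEndConnNS):
        """Hypothetical synapse: update() no longer receives the shared `tdi` dict."""

        def __init__(self, pre, post, conn, tau=5.):
            super().__init__(pre=pre, post=post, conn=conn)
            self.tau = tau
            self.g = bm.Variable(bm.zeros(self.post.num))

        def update(self, pre_spike=None):
            if pre_spike is None:
                pre_spike = self.pre.spike
            # exponential decay of the conductance plus a toy all-to-one spike drive
            self.g.value = self.g - self.g / self.tau * bm.get_dt() + bm.sum(pre_spike)
            return self.g
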
brainpy/_src/dyn/runners.py (86 changes: 60 additions & 26 deletions)
@@ -17,8 +17,8 @@
from brainpy._src.dyn.base import DynamicalSystem
from brainpy._src.dyn.context import share
from brainpy._src.running.runner import Runner
from brainpy.check import is_float, serialize_kwargs
from brainpy.errors import RunningError, NoLongerSupportError
from brainpy.check import serialize_kwargs
from brainpy.errors import RunningError
from brainpy.types import ArrayType, Output, Monitor

__all__ = [
@@ -319,6 +319,7 @@ def __init__(
# jit
jit: Union[bool, Dict[str, bool]] = True,
dyn_vars: Optional[Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]]] = None,
memory_efficient: bool = False,

# extra info
dt: Optional[float] = None,
@@ -342,10 +343,9 @@ def __init__(
numpy_mon_after_run=numpy_mon_after_run)

# t0 and i0
is_float(t0, 't0', allow_none=False, allow_int=True)
self.i0 = 0
self._t0 = t0
self.i0 = bm.Variable(jnp.asarray(1, dtype=bm.int_))
self.t0 = bm.Variable(jnp.asarray(t0, dtype=bm.float_))
self.t0 = t0
if data_first_axis is None:
data_first_axis = 'B' if isinstance(self.target.mode, bm.BatchingMode) else 'T'
assert data_first_axis in ['B', 'T']
@@ -371,6 +371,11 @@ def __init__(
# run function
self._f_predict_compiled = dict()

# monitors
self._memory_efficient = memory_efficient
if memory_efficient and not numpy_mon_after_run:
raise ValueError('When setting "gpu_memory_efficient=True", "numpy_mon_after_run" can not be False.')

def __repr__(self):
name = self.__class__.__name__
indent = " " * len(name) + ' '
@@ -382,8 +387,8 @@ def __repr__(self):

def reset_state(self):
"""Reset state of the ``DSRunner``."""
self.i0.value = jnp.zeros_like(self.i0.value)
self.t0.value = jnp.ones_like(self.t0.value) * self._t0
self.i0 = 0
self.t0 = self._t0

def predict(
self,
@@ -438,11 +443,12 @@ def predict(
"""

if inputs_are_batching is not None:
raise NoLongerSupportError(
raise warnings.warn(
f'''
`inputs_are_batching` is no longer supported.
The target mode of {self.target.mode} has already indicated the input should be batching.
'''
''',
UserWarning
)
if duration is None:
if inputs is None:
@@ -466,7 +472,7 @@
if shared_args is None:
shared_args = dict()
shared_args['fit'] = shared_args.get('fit', False)
shared = tools.DotDict(i=jnp.arange(num_step, dtype=bm.int_))
shared = tools.DotDict(i=np.arange(num_step, dtype=bm.int_))
shared['t'] = shared['i'] * self.dt
shared['i'] += self.i0
shared['t'] += self.t0
@@ -486,7 +492,8 @@
# running
if eval_time:
t0 = time.time()
outputs, hists = self._predict(xs=(shared['t'], shared['i'], inputs), shared_args=shared_args)
with jax.disable_jit(not self.jit['predict']):
outputs, hists = self._predict(xs=(shared['t'], shared['i'], inputs), shared_args=shared_args)
if eval_time:
running_time = time.time() - t0

@@ -495,11 +502,16 @@
self._pbar.close()

# post-running for monitors
hists['ts'] = shared['t'] + self.dt
if self.numpy_mon_after_run:
hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.Array))
for key in hists.keys():
self.mon[key] = hists[key]
if self._memory_efficient:
self.mon['ts'] = shared['t'] + self.dt
for key in self.mon.var_names:
self.mon[key] = np.asarray(self.mon[key])
else:
hists['ts'] = shared['t'] + self.dt
if self.numpy_mon_after_run:
hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.Array))
for key in hists.keys():
self.mon[key] = hists[key]
self.i0 += num_step
self.t0 += (num_step * self.dt if duration is None else duration)
return outputs if not eval_time else (running_time, outputs)
@@ -609,10 +621,13 @@ def _get_input_time_step(self, duration=None, xs=None) -> int:
raise ValueError(f'Number of time step is different across arrays in '
f'the provided "xs". We got {set(num_steps)}.')
return num_steps[0]

else:
raise ValueError

def _step_mon_on_cpu(self, args, transforms):
for key, val in args.items():
self.mon[key].append(val)

def _step_func_predict(self, shared_args, t, i, x):
# input step
shared = tools.DotDict(t=t, i=i, dt=self.dt)
@@ -633,7 +648,12 @@ def _step_func_predict(self, shared_args, t, i, x):
if self.progress_bar:
id_tap(lambda *arg: self._pbar.update(), ())
share.clear_shargs()
return out, mon

if self._memory_efficient:
id_tap(self._step_mon_on_cpu, mon)
return out, None
else:
return out, mon

def _get_f_predict(self, shared_args: Dict = None):
if shared_args is None:
@@ -646,16 +666,30 @@
dyn_vars.update(self.vars(level=0))
dyn_vars = dyn_vars.unique()

def run_func(all_inputs):
return bm.for_loop(partial(self._step_func_predict, shared_args),
all_inputs,
dyn_vars=dyn_vars,
jit=self.jit['predict'])
if self._memory_efficient:
_jit_step = bm.jit(partial(self._step_func_predict, shared_args), dyn_vars=dyn_vars)

def run_func(all_inputs):
outs = None
times, indices, xs = all_inputs
for i in range(times.shape[0]):
out, _ = _jit_step(times[i], indices[i], tree_map(lambda a: a[i], xs))
if outs is None:
outs = tree_map(lambda a: [], out)
outs = tree_map(lambda a, o: o.append(a), out, outs)
outs = tree_map(lambda a: bm.as_jax(a), outs)
return outs, None

if self.jit['predict']:
self._f_predict_compiled[shared_kwargs_str] = bm.jit(run_func, dyn_vars=dyn_vars)
else:
self._f_predict_compiled[shared_kwargs_str] = run_func
@bm.jit(dyn_vars=dyn_vars)
def run_func(all_inputs):
return bm.for_loop(partial(self._step_func_predict, shared_args),
all_inputs,
dyn_vars=dyn_vars,
jit=self.jit['predict'])

self._f_predict_compiled[shared_kwargs_str] = run_func

return self._f_predict_compiled[shared_kwargs_str]

def __del__(self):
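
Taken together, the `runners.py` changes add a `memory_efficient` option to `DSRunner`: instead of accumulating the whole monitor history on the device inside `bm.for_loop`, each step is jitted on its own and monitored values are shipped to host memory via `id_tap`, then converted to NumPy after the run. This trades per-step dispatch overhead for a much smaller device-memory footprint on long simulations. A usage sketch under the new flag (the LIF model, input value, and duration are placeholders, not from this commit):

    import brainpy as bp

    net = bp.neurons.LIF(4000)                 # placeholder model
    runner = bp.DSRunner(
        net,
        monitors=['V', 'spike'],
        inputs=('input', 20.),
        memory_efficient=True,                 # new flag added in this commit
        # numpy_mon_after_run must stay True (its default) when memory_efficient=True,
        # otherwise __init__ raises a ValueError
    )
    runner.run(1000.)                          # simulate 1000 ms
    print(runner.mon['V'].shape)               # monitor history already lives in NumPy arrays
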
brainpy/_src/initialize/generic.py (7 changes: 5 additions & 2 deletions)
@@ -168,7 +168,7 @@ def variable(
if isinstance(batch_size_or_mode, bm.NonBatchingMode):
return bm.Variable(init(size))
elif isinstance(batch_size_or_mode, bm.BatchingMode):
new_shape = size[:batch_axis] + (1,) + size[batch_axis:]
new_shape = size[:batch_axis] + (batch_size_or_mode.batch_size,) + size[batch_axis:]
return bm.Variable(init(new_shape), batch_axis=batch_axis)
elif batch_size_or_mode in (None, False):
return bm.Variable(init(size))
@@ -185,7 +185,10 @@
if isinstance(batch_size_or_mode, bm.NonBatchingMode):
return bm.Variable(init)
elif isinstance(batch_size_or_mode, bm.BatchingMode):
return bm.Variable(bm.expand_dims(init, axis=batch_axis), batch_axis=batch_axis)
return bm.Variable(bm.repeat(bm.expand_dims(init, axis=batch_axis),
batch_size_or_mode.batch_size,
axis=batch_axis),
batch_axis=batch_axis)
elif batch_size_or_mode in (None, False):
return bm.Variable(init)
elif isinstance(batch_size_or_mode, int):
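
The two edits to `variable()` above make batch-mode allocation concrete: under a `BatchingMode`, the batch axis is now created at the mode's actual `batch_size` (and an array-valued initializer is repeated along it) instead of a length-1 axis. A quick sketch of the expected effect for a built-in model, which allocates its state through this helper (shapes shown are what the new behaviour implies; the `mode=` keyword of neuron groups is assumed):

    import brainpy as bp
    import brainpy.math as bm

    group = bp.neurons.LIF(100, mode=bm.BatchingMode(batch_size=8))
    print(group.V.shape)   # (8, 100) after this commit; previously the batch axis had length 1
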
brainpy/_src/math/environment.py (38 changes: 22 additions & 16 deletions)
@@ -490,20 +490,23 @@ class training_environment(environment):
"""

def __init__(self,
dt: float = None,
x64: bool = None,
complex_: type = None,
float_: type = None,
int_: type = None,
bool_: type = None):
def __init__(
self,
dt: float = None,
x64: bool = None,
complex_: type = None,
float_: type = None,
int_: type = None,
bool_: type = None,
batch_size: int = 1,
):
super().__init__(dt=dt,
x64=x64,
complex_=complex_,
float_=float_,
int_=int_,
bool_=bool_,
mode=modes.TrainingMode())
mode=modes.TrainingMode(batch_size))


class batching_environment(environment):
@@ -519,20 +522,23 @@ class batching_environment(environment):
"""

def __init__(self,
dt: float = None,
x64: bool = None,
complex_: type = None,
float_: type = None,
int_: type = None,
bool_: type = None):
def __init__(
self,
dt: float = None,
x64: bool = None,
complex_: type = None,
float_: type = None,
int_: type = None,
bool_: type = None,
batch_size: int = 1,
):
super().__init__(dt=dt,
x64=x64,
complex_=complex_,
float_=float_,
int_=int_,
bool_=bool_,
mode=modes.BatchingMode())
mode=modes.BatchingMode(batch_size))


def enable_x64():
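
With the above, the `training_environment` and `batching_environment` context managers accept the batch size directly and build the corresponding mode for everything constructed inside the block. A small sketch (assuming the usual `brainpy.math` helpers `get_dt()` and `get_mode()`):

    import brainpy.math as bm

    with bm.batching_environment(dt=0.1, batch_size=32):
        print(bm.get_dt())     # 0.1
        print(bm.get_mode())   # BatchingMode(batch_size=32)

    with bm.training_environment(batch_size=16):
        print(bm.get_mode())   # TrainingMode(batch_size=16)
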
brainpy/_src/math/modes.py (10 changes: 7 additions & 3 deletions)
@@ -26,7 +26,7 @@ def __eq__(self, other: 'Mode'):
def is_a(self, mode: type):
assert isinstance(mode, type)
return self.__class__ == mode

def is_parent_of(self, *modes):
cls = self.__class__
for smode in modes:
@@ -58,7 +58,12 @@ class BatchingMode(Mode):
:py:class:`~.NonBatchingMode` is usually used in models of model trainings.
"""
pass

def __init__(self, batch_size: int = 1):
self.batch_size = batch_size

def __repr__(self):
return f'{self.__class__.__name__}(batch_size={self.batch_size})'


class TrainingMode(BatchingMode):
Expand All @@ -74,4 +79,3 @@ class TrainingMode(BatchingMode):

training_mode = TrainingMode()
'''Default instance of the training computation mode.'''

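
Finally, `BatchingMode` (and thus `TrainingMode`) now carries an explicit `batch_size` and reports it in its repr; the module-level default instances keep `batch_size=1`. For example:

    import brainpy.math as bm

    mode = bm.TrainingMode(batch_size=16)
    print(mode)               # TrainingMode(batch_size=16)
    print(mode.batch_size)    # 16
    print(bm.batching_mode)   # BatchingMode(batch_size=1) -- the default instance
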