From 31dc8794edf1d5bf89b48ce16d73a0f9bba2757d Mon Sep 17 00:00:00 2001
From: chaoming
Date: Thu, 6 Apr 2023 14:58:26 +0800
Subject: [PATCH 1/4] support `batch_size` for `brainpy.math.BatchingMode`.
 This enables model initialization to use the desired batch size.

---
 brainpy/_src/initialize/generic.py |  7 ++++--
 brainpy/_src/math/environment.py   | 38 +++++++++++++++++-------------
 brainpy/_src/math/modes.py         | 10 +++++---
 3 files changed, 34 insertions(+), 21 deletions(-)

diff --git a/brainpy/_src/initialize/generic.py b/brainpy/_src/initialize/generic.py
index a265f4f11..3ca400e30 100644
--- a/brainpy/_src/initialize/generic.py
+++ b/brainpy/_src/initialize/generic.py
@@ -168,7 +168,7 @@ def variable(
   if isinstance(batch_size_or_mode, bm.NonBatchingMode):
     return bm.Variable(init(size))
   elif isinstance(batch_size_or_mode, bm.BatchingMode):
-    new_shape = size[:batch_axis] + (1,) + size[batch_axis:]
+    new_shape = size[:batch_axis] + (batch_size_or_mode.batch_size,) + size[batch_axis:]
     return bm.Variable(init(new_shape), batch_axis=batch_axis)
   elif batch_size_or_mode in (None, False):
     return bm.Variable(init(size))
@@ -185,7 +185,10 @@ def variable(
   if isinstance(batch_size_or_mode, bm.NonBatchingMode):
     return bm.Variable(init)
   elif isinstance(batch_size_or_mode, bm.BatchingMode):
-    return bm.Variable(bm.expand_dims(init, axis=batch_axis), batch_axis=batch_axis)
+    return bm.Variable(bm.repeat(bm.expand_dims(init, axis=batch_axis),
+                                 batch_size_or_mode.batch_size,
+                                 axis=batch_axis),
+                       batch_axis=batch_axis)
   elif batch_size_or_mode in (None, False):
     return bm.Variable(init)
   elif isinstance(batch_size_or_mode, int):
diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py
index a51a5c35a..0176d0d08 100644
--- a/brainpy/_src/math/environment.py
+++ b/brainpy/_src/math/environment.py
@@ -490,20 +490,23 @@ class training_environment(environment):
 
   """
 
-  def __init__(self,
-               dt: float = None,
-               x64: bool = None,
-               complex_: type = None,
-               float_: type = None,
-               int_: type = None,
-               bool_: type = None):
+  def __init__(
+      self,
+      dt: float = None,
+      x64: bool = None,
+      complex_: type = None,
+      float_: type = None,
+      int_: type = None,
+      bool_: type = None,
+      batch_size: int = 1,
+  ):
     super().__init__(dt=dt,
                      x64=x64,
                      complex_=complex_,
                      float_=float_,
                      int_=int_,
                      bool_=bool_,
-                     mode=modes.TrainingMode())
+                     mode=modes.TrainingMode(batch_size))
 
 
 class batching_environment(environment):
@@ -519,20 +522,23 @@ class batching_environment(environment):
 
   """
 
-  def __init__(self,
-               dt: float = None,
-               x64: bool = None,
-               complex_: type = None,
-               float_: type = None,
-               int_: type = None,
-               bool_: type = None):
+  def __init__(
+      self,
+      dt: float = None,
+      x64: bool = None,
+      complex_: type = None,
+      float_: type = None,
+      int_: type = None,
+      bool_: type = None,
+      batch_size: int = 1,
+  ):
     super().__init__(dt=dt,
                      x64=x64,
                      complex_=complex_,
                      float_=float_,
                      int_=int_,
                      bool_=bool_,
-                     mode=modes.BatchingMode())
+                     mode=modes.BatchingMode(batch_size))
 
 
 def enable_x64():
diff --git a/brainpy/_src/math/modes.py b/brainpy/_src/math/modes.py
index e619b2172..f3c126a99 100644
--- a/brainpy/_src/math/modes.py
+++ b/brainpy/_src/math/modes.py
@@ -26,7 +26,7 @@ def __eq__(self, other: 'Mode'):
   def is_a(self, mode: type):
     assert isinstance(mode, type)
     return self.__class__ == mode
-  
+
   def is_parent_of(self, *modes):
     cls = self.__class__
     for smode in modes:
@@ -58,7 +58,12 @@ class BatchingMode(Mode):
   :py:class:`~.NonBatchingMode` is usually used in models of model trainings.
""" - pass + + def __init__(self, batch_size: int = 1): + self.batch_size = batch_size + + def __repr__(self): + return f'{self.__class__.__name__}(batch_size={self.batch_size})' class TrainingMode(BatchingMode): @@ -74,4 +79,3 @@ class TrainingMode(BatchingMode): training_mode = TrainingMode() '''Default instance of the training computation mode.''' - From d336649615c18454596b75960572d9a77a734fe4 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 7 Apr 2023 10:39:23 +0800 Subject: [PATCH 2/4] add `TwoEndConnNS` --- brainpy/__init__.py | 4 +++- brainpy/_src/dyn/base.py | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 765ba90bc..6dbebb8b1 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -81,7 +81,9 @@ from brainpy._src.dyn.transform import (LoopOverTime as LoopOverTime,) # DynamicalSystem base classes from brainpy._src.dyn.base import (DynamicalSystemNS as DynamicalSystemNS, - NeuGroupNS as NeuGroupNS) + NeuGroupNS as NeuGroupNS, + TwoEndConnNS as TwoEndConnNS, + ) from brainpy._src.dyn.synapses_v2.base import (SynOutNS as SynOutNS, SynSTPNS as SynSTPNS, SynConnNS as SynConnNS, ) diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py index 01300f141..f104012fd 100644 --- a/brainpy/_src/dyn/base.py +++ b/brainpy/_src/dyn/base.py @@ -1024,6 +1024,11 @@ def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat): return post_vs +class TwoEndConnNS(TwoEndConn): + """Two-end connection without passing shared arguments.""" + _pass_shared_args = False + + class CondNeuGroup(NeuGroup, Container): r"""Base class to model conductance-based neuron group. From 544e749224f5fb3b1a312cce2fc0d3cb13ce81c8 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 7 Apr 2023 10:39:36 +0800 Subject: [PATCH 3/4] update jit-conn-ops doc --- brainpy/_src/math/operators/jitconn_ops.py | 289 +++++++++++---------- 1 file changed, 149 insertions(+), 140 deletions(-) diff --git a/brainpy/_src/math/operators/jitconn_ops.py b/brainpy/_src/math/operators/jitconn_ops.py index 022466300..fcfc96158 100644 --- a/brainpy/_src/math/operators/jitconn_ops.py +++ b/brainpy/_src/math/operators/jitconn_ops.py @@ -23,70 +23,7 @@ ] -def event_matvec_prob_conn_homo_weight( - events: jnp.ndarray, - weight: float, - *, - conn_prob: float, - shape: Tuple[int, int], - seed: Optional[int] = None, - transpose: bool = False, - outdim_parallel: bool = True, -) -> jnp.ndarray: - bl = tools.import_brainpylib() - return bl.jitconn_ops.event_matvec_prob_conn_homo_weight(events, weight, - conn_prob=conn_prob, - shape=shape, - seed=seed, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def event_matvec_prob_conn_uniform_weight( - events: jnp.ndarray, - *, - w_low: float, - w_high: float, - conn_prob: float, - shape: Tuple[int, int], - seed: Optional[int] = None, - transpose: bool = False, - outdim_parallel: bool = True, -) -> jnp.ndarray: - bl = tools.import_brainpylib() - return bl.jitconn_ops.event_matvec_prob_conn_uniform_weight(events, - w_low=w_low, - w_high=w_high, - conn_prob=conn_prob, - shape=shape, - seed=seed, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -def event_matvec_prob_conn_normal_weight( - events: jnp.ndarray, - *, - w_mu: float, - w_sigma: float, - conn_prob: float, - shape: Tuple[int, int], - seed: Optional[int] = None, - transpose: bool = False, - outdim_parallel: bool = True, -) -> jnp.ndarray: - bl = tools.import_brainpylib() - return bl.jitconn_ops.event_matvec_prob_conn_normal_weight(events, - 
w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=conn_prob, - shape=shape, - seed=seed, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -def matmat_prob_conn_homo_weight( +def matvec_prob_conn_homo_weight( vector: jnp.ndarray, weight: float, *, @@ -95,9 +32,10 @@ def matmat_prob_conn_homo_weight( seed: Optional[int] = None, transpose: bool = False, outdim_parallel: bool = True, + version: str = 'v2' ) -> jnp.ndarray: - r"""Perform the :math:`Y=X@M` operation, where :math:`X`, :math:`Y` and :math:`M` are matrices, - and :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations on CPU and GPU devices. @@ -109,7 +47,7 @@ def matmat_prob_conn_homo_weight( In this operation, :math:`M` is the random matrix with a connection probability `conn_prob`, and at each connection the value is the same scalar `weight`. - When ``transpose=True``, we perform an operation of :math:`Y=X@M^T`. + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. .. note:: @@ -142,27 +80,30 @@ def matmat_prob_conn_homo_weight( Returns ------- out: Array, ndarray - The output of :math:`Y = X @ M`. + The output of :math:`y = M @ v`. """ bl = tools.import_brainpylib() - return bl.jitconn_ops.matmat_prob_conn_homo_weight(vector, + return bl.jitconn_ops.matvec_prob_conn_homo_weight(vector, weight, conn_prob=conn_prob, shape=shape, seed=seed, transpose=transpose, - outdim_parallel=outdim_parallel) + outdim_parallel=outdim_parallel, + version=version) -def matmat_prob_conn_uniform_weight( - matrix: jax.Array, +def matvec_prob_conn_uniform_weight( + vector: jnp.ndarray, *, w_low: float, w_high: float, conn_prob: float, shape: Tuple[int, int], seed: Optional[int] = None, - version: str = 'v1' + transpose: bool = False, + outdim_parallel: bool = True, + version: str = 'v2' ) -> jnp.ndarray: r"""Perform the :math:`y=M@v` operation, where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. @@ -190,8 +131,8 @@ def matmat_prob_conn_uniform_weight( Parameters ---------- - matrix: Array - The matrix :math:`X`. + vector: Array, ndarray + The vector. w_low: float Lower boundary of the output interval. w_high: float @@ -202,6 +143,12 @@ def matmat_prob_conn_uniform_weight( The matrix shape. seed: int The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. Returns ------- @@ -209,27 +156,31 @@ def matmat_prob_conn_uniform_weight( The output of :math:`y = M @ v`. 
""" bl = tools.import_brainpylib() - return bl.jitconn_ops.matmat_prob_conn_uniform_weight(matrix, + return bl.jitconn_ops.matvec_prob_conn_uniform_weight(vector, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, + transpose=transpose, + outdim_parallel=outdim_parallel, version=version) -def matmat_prob_conn_normal_weight( - matrix: jax.Array, +def matvec_prob_conn_normal_weight( + vector: jnp.ndarray, *, w_mu: float, w_sigma: float, conn_prob: float, shape: Tuple[int, int], seed: Optional[int] = None, - version: str = 'v1' -) -> jax.Array: - r"""Perform the :math:`Y=X@M` operation, where :math:`X`, :math:`Y` and :math:`M` are matrices, - and :math:`M` is just-in-time randomly generated with a normal distribution for its value. + transpose: bool = False, + outdim_parallel: bool = True, + version: str = 'v2' +) -> jnp.ndarray: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a normal distribution for its value. This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations on CPU and GPU devices. @@ -241,10 +192,21 @@ def matmat_prob_conn_normal_weight( In this operation, :math:`M` is the random matrix with a connection probability `conn_prob`, and at each connection the value is the same scalar `weight`. + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + Parameters ---------- - matrix: Array - The matrix :math:`X`. + vector: Array, ndarray + The vector. w_mu: float Mean (centre) of the distribution. w_sigma: float @@ -255,6 +217,12 @@ def matmat_prob_conn_normal_weight( The matrix shape. seed: int The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. Returns ------- @@ -262,16 +230,90 @@ def matmat_prob_conn_normal_weight( The output of :math:`y = M @ v`. 
""" bl = tools.import_brainpylib() - return bl.jitconn_ops.matmat_prob_conn_normal_weight(matrix, + return bl.jitconn_ops.matvec_prob_conn_normal_weight(vector, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, + transpose=transpose, + outdim_parallel=outdim_parallel, version=version) -def matvec_prob_conn_homo_weight( +def event_matvec_prob_conn_homo_weight( + events: jnp.ndarray, + weight: float, + *, + conn_prob: float, + shape: Tuple[int, int], + seed: Optional[int] = None, + transpose: bool = False, + outdim_parallel: bool = True, +) -> jnp.ndarray: + bl = tools.import_brainpylib() + return bl.jitconn_ops.event_matvec_prob_conn_homo_weight(events, weight, + conn_prob=conn_prob, + shape=shape, + seed=seed, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +event_matvec_prob_conn_homo_weight.__doc__ = matvec_prob_conn_homo_weight.__doc__ + + +def event_matvec_prob_conn_uniform_weight( + events: jnp.ndarray, + *, + w_low: float, + w_high: float, + conn_prob: float, + shape: Tuple[int, int], + seed: Optional[int] = None, + transpose: bool = False, + outdim_parallel: bool = True, +) -> jnp.ndarray: + bl = tools.import_brainpylib() + return bl.jitconn_ops.event_matvec_prob_conn_uniform_weight(events, + w_low=w_low, + w_high=w_high, + conn_prob=conn_prob, + shape=shape, + seed=seed, + transpose=transpose, + outdim_parallel=outdim_parallel)[0] + + +event_matvec_prob_conn_uniform_weight.__doc__ = matvec_prob_conn_uniform_weight.__doc__ + + +def event_matvec_prob_conn_normal_weight( + events: jnp.ndarray, + *, + w_mu: float, + w_sigma: float, + conn_prob: float, + shape: Tuple[int, int], + seed: Optional[int] = None, + transpose: bool = False, + outdim_parallel: bool = True, +) -> jnp.ndarray: + bl = tools.import_brainpylib() + return bl.jitconn_ops.event_matvec_prob_conn_normal_weight(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=conn_prob, + shape=shape, + seed=seed, + transpose=transpose, + outdim_parallel=outdim_parallel)[0] + + +event_matvec_prob_conn_normal_weight.__doc__ = matvec_prob_conn_normal_weight.__doc__ + + +def matmat_prob_conn_homo_weight( vector: jnp.ndarray, weight: float, *, @@ -280,10 +322,9 @@ def matvec_prob_conn_homo_weight( seed: Optional[int] = None, transpose: bool = False, outdim_parallel: bool = True, - version: str = 'v2' ) -> jnp.ndarray: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. + r"""Perform the :math:`Y=X@M` operation, where :math:`X`, :math:`Y` and :math:`M` are matrices, + and :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations on CPU and GPU devices. @@ -295,7 +336,7 @@ def matvec_prob_conn_homo_weight( In this operation, :math:`M` is the random matrix with a connection probability `conn_prob`, and at each connection the value is the same scalar `weight`. - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + When ``transpose=True``, we perform an operation of :math:`Y=X@M^T`. .. note:: @@ -328,30 +369,27 @@ def matvec_prob_conn_homo_weight( Returns ------- out: Array, ndarray - The output of :math:`y = M @ v`. + The output of :math:`Y = X @ M`. 
""" bl = tools.import_brainpylib() - return bl.jitconn_ops.matvec_prob_conn_homo_weight(vector, + return bl.jitconn_ops.matmat_prob_conn_homo_weight(vector, weight, conn_prob=conn_prob, shape=shape, seed=seed, transpose=transpose, - outdim_parallel=outdim_parallel, - version=version) + outdim_parallel=outdim_parallel) -def matvec_prob_conn_uniform_weight( - vector: jnp.ndarray, +def matmat_prob_conn_uniform_weight( + matrix: jax.Array, *, w_low: float, w_high: float, conn_prob: float, shape: Tuple[int, int], seed: Optional[int] = None, - transpose: bool = False, - outdim_parallel: bool = True, - version: str = 'v2' + version: str = 'v1' ) -> jnp.ndarray: r"""Perform the :math:`y=M@v` operation, where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. @@ -379,8 +417,8 @@ def matvec_prob_conn_uniform_weight( Parameters ---------- - vector: Array, ndarray - The vector. + matrix: Array + The matrix :math:`X`. w_low: float Lower boundary of the output interval. w_high: float @@ -391,12 +429,6 @@ def matvec_prob_conn_uniform_weight( The matrix shape. seed: int The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. Returns ------- @@ -404,31 +436,27 @@ def matvec_prob_conn_uniform_weight( The output of :math:`y = M @ v`. """ bl = tools.import_brainpylib() - return bl.jitconn_ops.matvec_prob_conn_uniform_weight(vector, + return bl.jitconn_ops.matmat_prob_conn_uniform_weight(matrix, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, - transpose=transpose, - outdim_parallel=outdim_parallel, version=version) -def matvec_prob_conn_normal_weight( - vector: jnp.ndarray, +def matmat_prob_conn_normal_weight( + matrix: jax.Array, *, w_mu: float, w_sigma: float, conn_prob: float, shape: Tuple[int, int], seed: Optional[int] = None, - transpose: bool = False, - outdim_parallel: bool = True, - version: str = 'v2' -) -> jnp.ndarray: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. + version: str = 'v1' +) -> jax.Array: + r"""Perform the :math:`Y=X@M` operation, where :math:`X`, :math:`Y` and :math:`M` are matrices, + and :math:`M` is just-in-time randomly generated with a normal distribution for its value. This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations on CPU and GPU devices. @@ -440,21 +468,10 @@ def matvec_prob_conn_normal_weight( In this operation, :math:`M` is the random matrix with a connection probability `conn_prob`, and at each connection the value is the same scalar `weight`. - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - Parameters ---------- - vector: Array, ndarray - The vector. + matrix: Array + The matrix :math:`X`. w_mu: float Mean (centre) of the distribution. 
w_sigma: float @@ -465,12 +482,6 @@ def matvec_prob_conn_normal_weight( The matrix shape. seed: int The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. Returns ------- @@ -478,12 +489,10 @@ def matvec_prob_conn_normal_weight( The output of :math:`y = M @ v`. """ bl = tools.import_brainpylib() - return bl.jitconn_ops.matvec_prob_conn_normal_weight(vector, + return bl.jitconn_ops.matmat_prob_conn_normal_weight(matrix, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, - transpose=transpose, - outdim_parallel=outdim_parallel, version=version) From 6f732321bd577d5ccfdd43c25f2b5494d249f691 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 7 Apr 2023 20:00:38 +0800 Subject: [PATCH 4/4] enable monitoring GPU models on CPU when setting `DSRunner(..., memory_efficient=True)` --- brainpy/__init__.py | 1 - brainpy/_src/dyn/base.py | 2 +- brainpy/_src/dyn/runners.py | 86 ++++++++++++++++++++++++++----------- 3 files changed, 61 insertions(+), 28 deletions(-) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 6dbebb8b1..9f277ca72 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -113,7 +113,6 @@ from . import running, testing from ._src.visualization import (visualize as visualize) -from ._src.running.runner import (Runner as Runner) # Part 7: Deprecations # diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py index f104012fd..f273d6973 100644 --- a/brainpy/_src/dyn/base.py +++ b/brainpy/_src/dyn/base.py @@ -767,7 +767,7 @@ def check_post_attrs(self, *attrs): if not hasattr(self.post, attr): raise ValueError(f'{self} need "pre" neuron group has attribute "{attr}".') - def update(self, tdi, pre_spike=None): + def update(self, *args, **kwargs): """The function to specify the updating rule. 
Assume any dynamical system depends on the shared variables (`sha`), diff --git a/brainpy/_src/dyn/runners.py b/brainpy/_src/dyn/runners.py index cedc1ca76..24b281f47 100644 --- a/brainpy/_src/dyn/runners.py +++ b/brainpy/_src/dyn/runners.py @@ -17,8 +17,8 @@ from brainpy._src.dyn.base import DynamicalSystem from brainpy._src.dyn.context import share from brainpy._src.running.runner import Runner -from brainpy.check import is_float, serialize_kwargs -from brainpy.errors import RunningError, NoLongerSupportError +from brainpy.check import serialize_kwargs +from brainpy.errors import RunningError from brainpy.types import ArrayType, Output, Monitor __all__ = [ @@ -319,6 +319,7 @@ def __init__( # jit jit: Union[bool, Dict[str, bool]] = True, dyn_vars: Optional[Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]]] = None, + memory_efficient: bool = False, # extra info dt: Optional[float] = None, @@ -342,10 +343,9 @@ def __init__( numpy_mon_after_run=numpy_mon_after_run) # t0 and i0 - is_float(t0, 't0', allow_none=False, allow_int=True) + self.i0 = 0 self._t0 = t0 - self.i0 = bm.Variable(jnp.asarray(1, dtype=bm.int_)) - self.t0 = bm.Variable(jnp.asarray(t0, dtype=bm.float_)) + self.t0 = t0 if data_first_axis is None: data_first_axis = 'B' if isinstance(self.target.mode, bm.BatchingMode) else 'T' assert data_first_axis in ['B', 'T'] @@ -371,6 +371,11 @@ def __init__( # run function self._f_predict_compiled = dict() + # monitors + self._memory_efficient = memory_efficient + if memory_efficient and not numpy_mon_after_run: + raise ValueError('When setting "gpu_memory_efficient=True", "numpy_mon_after_run" can not be False.') + def __repr__(self): name = self.__class__.__name__ indent = " " * len(name) + ' ' @@ -382,8 +387,8 @@ def __repr__(self): def reset_state(self): """Reset state of the ``DSRunner``.""" - self.i0.value = jnp.zeros_like(self.i0.value) - self.t0.value = jnp.ones_like(self.t0.value) * self._t0 + self.i0 = 0 + self.t0 = self._t0 def predict( self, @@ -438,11 +443,12 @@ def predict( """ if inputs_are_batching is not None: - raise NoLongerSupportError( + raise warnings.warn( f''' `inputs_are_batching` is no longer supported. The target mode of {self.target.mode} has already indicated the input should be batching. 
- ''' + ''', + UserWarning ) if duration is None: if inputs is None: @@ -466,7 +472,7 @@ def predict( if shared_args is None: shared_args = dict() shared_args['fit'] = shared_args.get('fit', False) - shared = tools.DotDict(i=jnp.arange(num_step, dtype=bm.int_)) + shared = tools.DotDict(i=np.arange(num_step, dtype=bm.int_)) shared['t'] = shared['i'] * self.dt shared['i'] += self.i0 shared['t'] += self.t0 @@ -486,7 +492,8 @@ def predict( # running if eval_time: t0 = time.time() - outputs, hists = self._predict(xs=(shared['t'], shared['i'], inputs), shared_args=shared_args) + with jax.disable_jit(not self.jit['predict']): + outputs, hists = self._predict(xs=(shared['t'], shared['i'], inputs), shared_args=shared_args) if eval_time: running_time = time.time() - t0 @@ -495,11 +502,16 @@ def predict( self._pbar.close() # post-running for monitors - hists['ts'] = shared['t'] + self.dt - if self.numpy_mon_after_run: - hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.Array)) - for key in hists.keys(): - self.mon[key] = hists[key] + if self._memory_efficient: + self.mon['ts'] = shared['t'] + self.dt + for key in self.mon.var_names: + self.mon[key] = np.asarray(self.mon[key]) + else: + hists['ts'] = shared['t'] + self.dt + if self.numpy_mon_after_run: + hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.Array)) + for key in hists.keys(): + self.mon[key] = hists[key] self.i0 += num_step self.t0 += (num_step * self.dt if duration is None else duration) return outputs if not eval_time else (running_time, outputs) @@ -609,10 +621,13 @@ def _get_input_time_step(self, duration=None, xs=None) -> int: raise ValueError(f'Number of time step is different across arrays in ' f'the provided "xs". We got {set(num_steps)}.') return num_steps[0] - else: raise ValueError + def _step_mon_on_cpu(self, args, transforms): + for key, val in args.items(): + self.mon[key].append(val) + def _step_func_predict(self, shared_args, t, i, x): # input step shared = tools.DotDict(t=t, i=i, dt=self.dt) @@ -633,7 +648,12 @@ def _step_func_predict(self, shared_args, t, i, x): if self.progress_bar: id_tap(lambda *arg: self._pbar.update(), ()) share.clear_shargs() - return out, mon + + if self._memory_efficient: + id_tap(self._step_mon_on_cpu, mon) + return out, None + else: + return out, mon def _get_f_predict(self, shared_args: Dict = None): if shared_args is None: @@ -646,16 +666,30 @@ def _get_f_predict(self, shared_args: Dict = None): dyn_vars.update(self.vars(level=0)) dyn_vars = dyn_vars.unique() - def run_func(all_inputs): - return bm.for_loop(partial(self._step_func_predict, shared_args), - all_inputs, - dyn_vars=dyn_vars, - jit=self.jit['predict']) + if self._memory_efficient: + _jit_step = bm.jit(partial(self._step_func_predict, shared_args), dyn_vars=dyn_vars) + + def run_func(all_inputs): + outs = None + times, indices, xs = all_inputs + for i in range(times.shape[0]): + out, _ = _jit_step(times[i], indices[i], tree_map(lambda a: a[i], xs)) + if outs is None: + outs = tree_map(lambda a: [], out) + outs = tree_map(lambda a, o: o.append(a), out, outs) + outs = tree_map(lambda a: bm.as_jax(a), outs) + return outs, None - if self.jit['predict']: - self._f_predict_compiled[shared_kwargs_str] = bm.jit(run_func, dyn_vars=dyn_vars) else: - self._f_predict_compiled[shared_kwargs_str] = run_func + @bm.jit(dyn_vars=dyn_vars) + def run_func(all_inputs): + return bm.for_loop(partial(self._step_func_predict, shared_args), + all_inputs, + dyn_vars=dyn_vars, + 
jit=self.jit['predict']) + + self._f_predict_compiled[shared_kwargs_str] = run_func + return self._f_predict_compiled[shared_kwargs_str] def __del__(self):
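
Taken together, the options introduced by this series are exercised roughly as below.
This is only a usage sketch, not part of the patches: the `LIF` population, the monitored
variable name `V`, and the run duration are illustrative; only `batch_size` on the
batching/training modes and environments, and `DSRunner(..., memory_efficient=True)`,
are assumed to behave as described in the diffs above.

import brainpy as bp
import brainpy.math as bm

# Batching/training modes now carry a batch size (default 1).
mode = bm.TrainingMode(batch_size=8)
print(mode)  # TrainingMode(batch_size=8)

# The environment helpers forward `batch_size` to the mode, so variables
# created inside are initialized with the requested batch dimension.
with bm.training_environment(batch_size=8):
  net = bp.neurons.LIF(10)  # illustrative model

# With `memory_efficient=True`, monitored values are transferred to host
# memory step by step instead of being stacked on the device.
runner = bp.DSRunner(net, monitors=['V'], memory_efficient=True)
runner.run(100.)
print(runner.mon['V'].shape)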