diff --git a/brainpy/_src/analysis/highdim/slow_points.py b/brainpy/_src/analysis/highdim/slow_points.py index 80361af62..87263eae3 100644 --- a/brainpy/_src/analysis/highdim/slow_points.py +++ b/brainpy/_src/analysis/highdim/slow_points.py @@ -343,8 +343,6 @@ def f_loss(): grad_f = bm.grad(f_loss, grad_vars=fixed_points, return_value=True) optimizer.register_train_vars(fixed_points if isinstance(fixed_points, dict) else {'a': fixed_points}) - dyn_vars = optimizer.vars() + (fixed_points if isinstance(fixed_points, dict) else {'a': fixed_points}) - dyn_vars = dyn_vars.unique() def train(idx): gradients, loss = grad_f() @@ -353,7 +351,7 @@ def train(idx): return loss def batch_train(start_i, n_batch): - return bm.for_loop(train, bm.arange(start_i, start_i + n_batch), dyn_vars=dyn_vars) + return bm.for_loop(train, bm.arange(start_i, start_i + n_batch)) # Run the optimization if self.verbose: diff --git a/brainpy/_src/analysis/lowdim/lowdim_analyzer.py b/brainpy/_src/analysis/lowdim/lowdim_analyzer.py index a9f76f7ef..78ab3f44c 100644 --- a/brainpy/_src/analysis/lowdim/lowdim_analyzer.py +++ b/brainpy/_src/analysis/lowdim/lowdim_analyzer.py @@ -4,6 +4,7 @@ from functools import partial import numpy as np +import jax from jax import numpy as jnp from jax import vmap from jax.scipy.optimize import minimize @@ -274,13 +275,13 @@ def F_fx(self): f = partial(f, **(self.pars_update + self.fixed_vars)) f = utils.f_without_jaxarray_return(f) f = utils.remove_return_shape(f) - self.analyzed_results[C.F_fx] = bm.jit(f, device=self.jit_device) + self.analyzed_results[C.F_fx] = jax.jit(f, device=self.jit_device) return self.analyzed_results[C.F_fx] @property def F_vmap_fx(self): if C.F_vmap_fx not in self.analyzed_results: - self.analyzed_results[C.F_vmap_fx] = bm.jit(vmap(self.F_fx), device=self.jit_device) + self.analyzed_results[C.F_vmap_fx] = jax.jit(vmap(self.F_fx), device=self.jit_device) return self.analyzed_results[C.F_vmap_fx] @property @@ -288,7 +289,7 @@ def F_dfxdx(self): """The function to evaluate :math:`\frac{df_x(*\mathrm{vars}, *\mathrm{pars})}{dx}`.""" if C.F_dfxdx not in self.analyzed_results: dfx = bm.vector_grad(self.F_fx, argnums=0) - self.analyzed_results[C.F_dfxdx] = bm.jit(dfx, device=self.jit_device) + self.analyzed_results[C.F_dfxdx] = jax.jit(dfx, device=self.jit_device) return self.analyzed_results[C.F_dfxdx] @property @@ -307,7 +308,7 @@ def F_vmap_fp_aux(self): # --- # "X": a two-dimensional matrix: (num_batch, num_var) # "args": a list of one-dimensional vectors, each has the shape of (num_batch,) - self.analyzed_results[C.F_vmap_fp_aux] = bm.jit(vmap(self.F_fixed_point_aux)) + self.analyzed_results[C.F_vmap_fp_aux] = jax.jit(vmap(self.F_fixed_point_aux)) return self.analyzed_results[C.F_vmap_fp_aux] @property @@ -326,7 +327,7 @@ def F_vmap_fp_opt(self): # --- # "X": a two-dimensional matrix: (num_batch, num_var) # "args": a list of one-dimensional vectors, each has the shape of (num_batch,) - self.analyzed_results[C.F_vmap_fp_opt] = bm.jit(vmap(self.F_fixed_point_opt)) + self.analyzed_results[C.F_vmap_fp_opt] = jax.jit(vmap(self.F_fixed_point_opt)) return self.analyzed_results[C.F_vmap_fp_opt] def _get_fixed_points(self, candidates, *args, num_seg=None, tol_aux=1e-7, loss_screen=None): @@ -519,7 +520,7 @@ def F_y_by_x_in_fy(self): @property def F_vmap_fy(self): if C.F_vmap_fy not in self.analyzed_results: - self.analyzed_results[C.F_vmap_fy] = bm.jit(vmap(self.F_fy), device=self.jit_device) + self.analyzed_results[C.F_vmap_fy] = jax.jit(vmap(self.F_fy), 
device=self.jit_device) return self.analyzed_results[C.F_vmap_fy] @property @@ -527,7 +528,7 @@ def F_dfxdy(self): """The function to evaluate :math:`\frac{df_x (*\mathrm{vars}, *\mathrm{pars})}{dy}`.""" if C.F_dfxdy not in self.analyzed_results: dfxdy = bm.vector_grad(self.F_fx, argnums=1) - self.analyzed_results[C.F_dfxdy] = bm.jit(dfxdy, device=self.jit_device) + self.analyzed_results[C.F_dfxdy] = jax.jit(dfxdy, device=self.jit_device) return self.analyzed_results[C.F_dfxdy] @property @@ -535,7 +536,7 @@ def F_dfydx(self): """The function to evaluate :math:`\frac{df_y (*\mathrm{vars}, *\mathrm{pars})}{dx}`.""" if C.F_dfydx not in self.analyzed_results: dfydx = bm.vector_grad(self.F_fy, argnums=0) - self.analyzed_results[C.F_dfydx] = bm.jit(dfydx, device=self.jit_device) + self.analyzed_results[C.F_dfydx] = jax.jit(dfydx, device=self.jit_device) return self.analyzed_results[C.F_dfydx] @property @@ -543,7 +544,7 @@ def F_dfydy(self): """The function to evaluate :math:`\frac{df_y (*\mathrm{vars}, *\mathrm{pars})}{dy}`.""" if C.F_dfydy not in self.analyzed_results: dfydy = bm.vector_grad(self.F_fy, argnums=1) - self.analyzed_results[C.F_dfydy] = bm.jit(dfydy, device=self.jit_device) + self.analyzed_results[C.F_dfydy] = jax.jit(dfydy, device=self.jit_device) return self.analyzed_results[C.F_dfydy] @property @@ -556,7 +557,7 @@ def f_jacobian(*var_and_pars): def call(*var_and_pars): var_and_pars = tuple((vp.value if isinstance(vp, bm.Array) else vp) for vp in var_and_pars) - return jnp.array(bm.jit(f_jacobian, device=self.jit_device)(*var_and_pars)) + return jnp.array(jax.jit(f_jacobian, device=self.jit_device)(*var_and_pars)) self.analyzed_results[C.F_jacobian] = call return self.analyzed_results[C.F_jacobian] @@ -681,7 +682,7 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux if self.F_x_by_y_in_fx is not None: utils.output("I am evaluating fx-nullcline by F_x_by_y_in_fx ...") - vmap_f = bm.jit(vmap(self.F_x_by_y_in_fx), device=self.jit_device) + vmap_f = jax.jit(vmap(self.F_x_by_y_in_fx), device=self.jit_device) for j, pars in enumerate(par_seg): if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...") mesh_values = jnp.meshgrid(*((ys,) + pars)) @@ -697,7 +698,7 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux elif self.F_y_by_x_in_fx is not None: utils.output("I am evaluating fx-nullcline by F_y_by_x_in_fx ...") - vmap_f = bm.jit(vmap(self.F_y_by_x_in_fx), device=self.jit_device) + vmap_f = jax.jit(vmap(self.F_y_by_x_in_fx), device=self.jit_device) for j, pars in enumerate(par_seg): if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...") mesh_values = jnp.meshgrid(*((xs,) + pars)) @@ -715,9 +716,9 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux utils.output("I am evaluating fx-nullcline by optimization ...") # auxiliary functions f2 = lambda y, x, *pars: self.F_fx(x, y, *pars) - vmap_f2 = bm.jit(vmap(f2), device=self.jit_device) - vmap_brentq_f2 = bm.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device) - vmap_brentq_f1 = bm.jit(vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device) + vmap_f2 = jax.jit(vmap(f2), device=self.jit_device) + vmap_brentq_f2 = jax.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device) + vmap_brentq_f1 = jax.jit(vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device) # num segments for _j, Ps in enumerate(par_seg): @@ -774,7 +775,7 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, 
num_segments=1, fp_aux if self.F_x_by_y_in_fy is not None: utils.output("I am evaluating fy-nullcline by F_x_by_y_in_fy ...") - vmap_f = bm.jit(vmap(self.F_x_by_y_in_fy), device=self.jit_device) + vmap_f = jax.jit(vmap(self.F_x_by_y_in_fy), device=self.jit_device) for j, pars in enumerate(par_seg): if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...") mesh_values = jnp.meshgrid(*((ys,) + pars)) @@ -790,7 +791,7 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux elif self.F_y_by_x_in_fy is not None: utils.output("I am evaluating fy-nullcline by F_y_by_x_in_fy ...") - vmap_f = bm.jit(vmap(self.F_y_by_x_in_fy), device=self.jit_device) + vmap_f = jax.jit(vmap(self.F_y_by_x_in_fy), device=self.jit_device) for j, pars in enumerate(par_seg): if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...") mesh_values = jnp.meshgrid(*((xs,) + pars)) @@ -809,9 +810,9 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux # auxiliary functions f2 = lambda y, x, *pars: self.F_fy(x, y, *pars) - vmap_f2 = bm.jit(vmap(f2), device=self.jit_device) - vmap_brentq_f2 = bm.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device) - vmap_brentq_f1 = bm.jit(vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device) + vmap_f2 = jax.jit(vmap(f2), device=self.jit_device) + vmap_brentq_f2 = jax.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device) + vmap_brentq_f1 = jax.jit(vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device) for j, Ps in enumerate(par_seg): if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...") @@ -859,7 +860,7 @@ def _get_fp_candidates_by_aux_rank(self, num_segments=1, num_rank=100): xs = self.resolutions[self.x_var] ys = self.resolutions[self.y_var] P = tuple(self.resolutions[p] for p in self.target_par_names) - f_select = bm.jit(vmap(lambda vals, ids: vals[ids], in_axes=(1, 1))) + f_select = jax.jit(vmap(lambda vals, ids: vals[ids], in_axes=(1, 1))) # num seguments if isinstance(num_segments, int): @@ -939,10 +940,10 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7, if self.convert_type() == C.x_by_y: num_seg = len(self.resolutions[self.y_var]) - f_vmap = bm.jit(vmap(self.F_y_convert[1])) + f_vmap = jax.jit(vmap(self.F_y_convert[1])) else: num_seg = len(self.resolutions[self.x_var]) - f_vmap = bm.jit(vmap(self.F_x_convert[1])) + f_vmap = jax.jit(vmap(self.F_x_convert[1])) # get the signs signs = jnp.sign(f_vmap(candidates, *args)) signs = signs.reshape((num_seg, -1)) @@ -972,10 +973,10 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7, # get another value if self.convert_type() == C.x_by_y: y_values = fps - x_values = bm.jit(vmap(self.F_y_convert[0]))(y_values, *args) + x_values = jax.jit(vmap(self.F_y_convert[0]))(y_values, *args) else: x_values = fps - y_values = bm.jit(vmap(self.F_x_convert[0]))(x_values, *args) + y_values = jax.jit(vmap(self.F_x_convert[0]))(x_values, *args) fps = jnp.stack([x_values, y_values]).T return fps, selected_ids, args @@ -1042,7 +1043,7 @@ def F_fz(self): wrapper = utils.std_derivative(arguments, self.target_var_names, self.target_par_names) f = wrapper(self.model.f_derivatives[self.z_var]) f = partial(f, **(self.pars_update + self.fixed_vars)) - self.analyzed_results[C.F_fz] = bm.jit(f, device=self.jit_device) + self.analyzed_results[C.F_fz] = jax.jit(f, device=self.jit_device) return self.analyzed_results[C.F_fz] def fz_signs(self, pars=(), cache=False): diff --git 
a/brainpy/_src/analysis/lowdim/lowdim_bifurcation.py b/brainpy/_src/analysis/lowdim/lowdim_bifurcation.py index 076b03c48..1ab064855 100644 --- a/brainpy/_src/analysis/lowdim/lowdim_bifurcation.py +++ b/brainpy/_src/analysis/lowdim/lowdim_bifurcation.py @@ -44,7 +44,7 @@ def __init__(self, model, target_pars, target_vars, fixed_vars=None, @property def F_vmap_dfxdx(self): if C.F_vmap_dfxdx not in self.analyzed_results: - f = bm.jit(vmap(bm.vector_grad(self.F_fx, argnums=0)), device=self.jit_device) + f = jax.jit(vmap(bm.vector_grad(self.F_fx, argnums=0)), device=self.jit_device) self.analyzed_results[C.F_vmap_dfxdx] = f return self.analyzed_results[C.F_vmap_dfxdx] @@ -163,7 +163,7 @@ def F_vmap_jacobian(self): if C.F_vmap_jacobian not in self.analyzed_results: f1 = lambda xy, *args: jnp.array([self.F_fx(xy[0], xy[1], *args), self.F_fy(xy[0], xy[1], *args)]) - f2 = bm.jit(vmap(bm.jacobian(f1)), device=self.jit_device) + f2 = jax.jit(vmap(bm.jacobian(f1)), device=self.jit_device) self.analyzed_results[C.F_vmap_jacobian] = f2 return self.analyzed_results[C.F_vmap_jacobian] diff --git a/brainpy/_src/analysis/lowdim/lowdim_phase_plane.py b/brainpy/_src/analysis/lowdim/lowdim_phase_plane.py index 667c62ec8..8a2aceaee 100644 --- a/brainpy/_src/analysis/lowdim/lowdim_phase_plane.py +++ b/brainpy/_src/analysis/lowdim/lowdim_phase_plane.py @@ -160,7 +160,7 @@ def __init__(self, @property def F_vmap_brentq_fy(self): if C.F_vmap_brentq_fy not in self.analyzed_results: - f_opt = bm.jit(vmap(utils.jax_brentq(self.F_fy))) + f_opt = jax.jit(vmap(utils.jax_brentq(self.F_fy))) self.analyzed_results[C.F_vmap_brentq_fy] = f_opt return self.analyzed_results[C.F_vmap_brentq_fy] diff --git a/brainpy/_src/analysis/lowdim/tests/test_phase_plane.py b/brainpy/_src/analysis/lowdim/tests/test_phase_plane.py index 3b5117273..5b666fbbe 100644 --- a/brainpy/_src/analysis/lowdim/tests/test_phase_plane.py +++ b/brainpy/_src/analysis/lowdim/tests/test_phase_plane.py @@ -7,7 +7,7 @@ import jax.numpy as jnp -block = False +show = False class TestPhasePlane(unittest.TestCase): @@ -27,7 +27,8 @@ def int_x(x, t, Iext): plt.ion() analyzer.plot_vector_field() analyzer.plot_fixed_point() - plt.show(block=block) + if show: + plt.show() plt.close() bp.math.disable_x64() @@ -74,6 +75,7 @@ def int_s2(s2, t, s1): analyzer.plot_vector_field() analyzer.plot_nullcline(coords=dict(s2='s2-s1')) analyzer.plot_fixed_point() - plt.show(block=block) + if show: + plt.show() plt.close() bp.math.disable_x64() diff --git a/brainpy/_src/analysis/utils/optimization.py b/brainpy/_src/analysis/utils/optimization.py index b452d90ca..270852327 100644 --- a/brainpy/_src/analysis/utils/optimization.py +++ b/brainpy/_src/analysis/utils/optimization.py @@ -197,7 +197,7 @@ def brentq_candidates(vmap_f, *values, args=()): def brentq_roots(f, starts, ends, *vmap_args, args=()): in_axes = (0, 0, tuple([0] * len(vmap_args)) + tuple([None] * len(args))) - vmap_f_opt = bm.jit(vmap(jax_brentq(f), in_axes=in_axes)) + vmap_f_opt = jax.jit(vmap(jax_brentq(f), in_axes=in_axes)) all_args = vmap_args + args if len(all_args): res = vmap_f_opt(starts, ends, all_args) @@ -397,7 +397,7 @@ def roots_of_1d_by_x(f, candidates, args=()): return fps starts = candidates[candidate_ids] ends = candidates[candidate_ids + 1] - f_opt = bm.jit(vmap(jax_brentq(f), in_axes=(0, 0, None))) + f_opt = jax.jit(vmap(jax_brentq(f), in_axes=(0, 0, None))) res = f_opt(starts, ends, args) valid_idx = jnp.where(res['status'] == ECONVERGED)[0] fps2 = res['root'][valid_idx] @@ -406,7 +406,7 @@ def 
roots_of_1d_by_x(f, candidates, args=()): def roots_of_1d_by_xy(f, starts, ends, args): f = f_without_jaxarray_return(f) - f_opt = bm.jit(vmap(jax_brentq(f))) + f_opt = jax.jit(vmap(jax_brentq(f))) res = f_opt(starts, ends, (args,)) valid_idx = jnp.where(res['status'] == ECONVERGED)[0] xs = res['root'][valid_idx] diff --git a/brainpy/_src/analysis/utils/others.py b/brainpy/_src/analysis/utils/others.py index 12411c9ab..27a592b81 100644 --- a/brainpy/_src/analysis/utils/others.py +++ b/brainpy/_src/analysis/utils/others.py @@ -2,9 +2,9 @@ from typing import Union, Dict +import jax import jax.numpy as jnp import numpy as np -from jax import vmap from jax.tree_util import tree_map import brainpy.math as bm @@ -80,7 +80,7 @@ def get_sign(f, xs, ys): def get_sign2(f, *xyz, args=()): in_axes = tuple(range(len(xyz))) + tuple([None] * len(args)) - f = bm.jit(vmap(f_without_jaxarray_return(f), in_axes=in_axes)) + f = jax.jit(jax.vmap(f_without_jaxarray_return(f), in_axes=in_axes)) xyz = tuple((v.value if isinstance(v, bm.Array) else v) for v in xyz) XYZ = jnp.meshgrid(*xyz) XYZ = tuple(jnp.moveaxis(v, 1, 0).flatten() for v in XYZ) diff --git a/brainpy/_src/dyn/runners.py b/brainpy/_src/dyn/runners.py index d7adb4e0e..8f0f9dbd2 100644 --- a/brainpy/_src/dyn/runners.py +++ b/brainpy/_src/dyn/runners.py @@ -668,13 +668,9 @@ def _get_f_predict(self, shared_args: Dict = None): shared_kwargs_str = serialize_kwargs(shared_args) if shared_kwargs_str not in self._f_predict_compiled: - dyn_vars = self.target.vars() - dyn_vars.update(self._dyn_vars) - dyn_vars.update(self.vars(level=0)) - dyn_vars = dyn_vars.unique() if self._memory_efficient: - _jit_step = bm.jit(partial(self._step_func_predict, shared_args), dyn_vars=dyn_vars) + _jit_step = bm.jit(partial(self._step_func_predict, shared_args)) def run_func(all_inputs): outs = None @@ -688,12 +684,10 @@ def run_func(all_inputs): return outs, None else: - @bm.jit(dyn_vars=dyn_vars) + step = partial(self._step_func_predict, shared_args) + def run_func(all_inputs): - return bm.for_loop(partial(self._step_func_predict, shared_args), - all_inputs, - dyn_vars=dyn_vars, - jit=self.jit['predict']) + return bm.for_loop(step, all_inputs, jit=self.jit['predict']) self._f_predict_compiled[shared_kwargs_str] = run_func diff --git a/brainpy/_src/dyn/synapses/abstract_models.py b/brainpy/_src/dyn/synapses/abstract_models.py index 3013db176..4bbdcc629 100644 --- a/brainpy/_src/dyn/synapses/abstract_models.py +++ b/brainpy/_src/dyn/synapses/abstract_models.py @@ -977,8 +977,7 @@ def update(self, tdi): inp = bm.cond((a > 5) * (b > 5), lambda _: self.rng.normal(a, b * p, self.target_var.shape), lambda _: self.rng.binomial(self.num_input, p, self.target_var.shape), - None, - dyn_vars=self.rng) + None) self.target_var += inp * self.weight def __repr__(self): diff --git a/brainpy/_src/dyn/synapses_v2/others.py b/brainpy/_src/dyn/synapses_v2/others.py index e21d9f881..1defe3782 100644 --- a/brainpy/_src/dyn/synapses_v2/others.py +++ b/brainpy/_src/dyn/synapses_v2/others.py @@ -68,8 +68,7 @@ def update(self): inp = bm.cond((a > 5) * (b > 5), lambda _: self.rng.normal(a, b * p, self.target_shape), lambda _: self.rng.binomial(self.num_input, p, self.target_shape), - None, - dyn_vars=self.rng) + None) return inp * self.weight def __repr__(self): diff --git a/brainpy/_src/dyn/transform.py b/brainpy/_src/dyn/transform.py index fdd93a1ed..d3263a5da 100644 --- a/brainpy/_src/dyn/transform.py +++ b/brainpy/_src/dyn/transform.py @@ -278,7 +278,7 @@ def __call__( else: shared = 
tools.DotDict() - if self.t0 is not None: + if self.t0 is not None: shared['t'] = jnp.arange(0, self.dt * length[0], self.dt) + self.t0.value if self.i0 is not None: shared['i'] = jnp.arange(0, length[0]) + self.i0.value @@ -286,7 +286,6 @@ def __call__( assert not self.no_state results = bm.for_loop(functools.partial(self._run, self.shared_arg), (shared, xs), - child_objs=(self.target, share), jit=self.jit, remat=self.remat) if self.i0 is not None: diff --git a/brainpy/_src/encoding/stateful_encoding.py b/brainpy/_src/encoding/stateful_encoding.py index 299960e1d..b40e4f427 100644 --- a/brainpy/_src/encoding/stateful_encoding.py +++ b/brainpy/_src/encoding/stateful_encoding.py @@ -84,7 +84,7 @@ def f(i): inputs.value -= w * spike return spike - return bm.for_loop(f, bm.arange(num_step).value, dyn_vars=inputs) + return bm.for_loop(f, bm.arange(num_step).value) class LatencyEncoder(Encoder): diff --git a/brainpy/_src/inputs/currents.py b/brainpy/_src/inputs/currents.py index dbaf57956..e91149572 100644 --- a/brainpy/_src/inputs/currents.py +++ b/brainpy/_src/inputs/currents.py @@ -309,7 +309,7 @@ def _f(t): x.value = x + dt * ((mean - x) / tau) + sigma * dt_sqrt * rng.rand(n) return x.value - noises = bm.for_loop(_f, jnp.arange(t_start, t_end, dt), dyn_vars=[x, rng]) + noises = bm.for_loop(_f, jnp.arange(t_start, t_end, dt)) t_end = duration if t_end is None else t_end i_start = int(t_start / dt) diff --git a/brainpy/_src/integrators/ode/exponential.py b/brainpy/_src/integrators/ode/exponential.py index c7add518d..74dd01dcc 100644 --- a/brainpy/_src/integrators/ode/exponential.py +++ b/brainpy/_src/integrators/ode/exponential.py @@ -283,7 +283,6 @@ class ExponentialEuler(ODEIntegrator): The default numerical integration step. name : optional, str The integrator name. 
- dyn_vars : optional, dict, sequence of ArrayType, ArrayType """ def __init__( @@ -293,7 +292,6 @@ def __init__( dt=None, name=None, show_code=False, - dyn_vars=None, state_delays=None, neutral_delays=None ): @@ -308,7 +306,6 @@ def __init__( if var_type == C.SYSTEM_VAR: raise NotImplementedError(f'{self.__class__.__name__} does not support {C.SYSTEM_VAR}, ' f'because the auto-differentiation ') - self.dyn_vars = dyn_vars # build the integrator self.code_lines = [] @@ -356,7 +353,7 @@ def _build_integrator(self, eq): eq=str(eq))) # gradient function - value_and_grad = bm.vector_grad(eq, argnums=0, dyn_vars=self.dyn_vars, return_value=True) + value_and_grad = bm.vector_grad(eq, argnums=0, return_value=True) # integration function def integral(*args, **kwargs): diff --git a/brainpy/_src/integrators/ode/tests/test_ode_method_rk.py b/brainpy/_src/integrators/ode/tests/test_ode_method_rk.py index 79f755ab4..08a7a5936 100644 --- a/brainpy/_src/integrators/ode/tests/test_ode_method_rk.py +++ b/brainpy/_src/integrators/ode/tests/test_ode_method_rk.py @@ -4,6 +4,7 @@ import numpy as np +import jax import brainpy.math as bm from brainpy._src.integrators.ode import explicit_rk plt = None @@ -27,7 +28,7 @@ def run_integrator(method, show=False): if plt is None: import matplotlib.pyplot as plt - f_integral = bm.jit(method(f_lorenz, dt=dt)) + f_integral = jax.jit(method(f_lorenz, dt=dt)) x = bm.Variable(bm.ones(1)) y = bm.Variable(bm.ones(1)) z = bm.Variable(bm.ones(1)) diff --git a/brainpy/_src/integrators/sde/base.py b/brainpy/_src/integrators/sde/base.py index a2a7abaa9..d624dcfb7 100644 --- a/brainpy/_src/integrators/sde/base.py +++ b/brainpy/_src/integrators/sde/base.py @@ -36,9 +36,7 @@ def __init__( intg_type: str = None, wiener_type: str = None, state_delays: Dict[str, AbstractDelay] = None, - dyn_vars: Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]] = None, ): - self.dyn_vars = dyn_vars dt = bm.get_dt() if dt is None else dt parses = utils.get_args(f) variables = parses[0] # variable names, (before 't') diff --git a/brainpy/_src/integrators/sde/normal.py b/brainpy/_src/integrators/sde/normal.py index fb01f7cbb..66e1ea4f0 100644 --- a/brainpy/_src/integrators/sde/normal.py +++ b/brainpy/_src/integrators/sde/normal.py @@ -84,13 +84,12 @@ class Euler(SDEIntegrator): def __init__( self, f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None, - state_delays=None, dyn_vars=None + state_delays=None, ): super(Euler, self).__init__(f=f, g=g, dt=dt, name=name, var_type=var_type, intg_type=intg_type, wiener_type=wiener_type, - state_delays=state_delays, - dyn_vars=dyn_vars) + state_delays=state_delays) self.set_integral(self.step) @@ -209,14 +208,13 @@ class Heun(Euler): def __init__(self, f, g, dt=None, name=None, show_code=False, var_type=None, intg_type=None, wiener_type=None, - state_delays=None, dyn_vars=None): + state_delays=None, ): if intg_type != constants.STRA_SDE: raise errors.IntegratorError(f'Heun method only supports Stranovich ' f'integral of SDEs, but we got {intg_type} integral.') super(Heun, self).__init__(f=f, g=g, dt=dt, name=name, var_type=var_type, intg_type=intg_type, - wiener_type=wiener_type, state_delays=state_delays, - dyn_vars=dyn_vars) + wiener_type=wiener_type, state_delays=state_delays) register_sde_integrator('heun', Heun) @@ -259,7 +257,6 @@ def __init__( intg_type: str = None, wiener_type: str = None, state_delays: Dict[str, bm.AbstractDelay] = None, - dyn_vars: Union[bm.Variable, Sequence[bm.Variable], Dict[str, 
bm.Variable]] = None, ): super(Milstein, self).__init__(f=f, g=g, @@ -268,8 +265,7 @@ def __init__( var_type=var_type, intg_type=intg_type, wiener_type=wiener_type, - state_delays=state_delays, - dyn_vars=dyn_vars) + state_delays=state_delays) self.set_integral(self.step) def _get_g_grad(self, f, allow_raise=False, need_grad=True): @@ -296,7 +292,7 @@ def _get_g_grad(self, f, allow_raise=False, need_grad=True): if not allow_raise: raise e if need_grad: - res[0] = bm.vector_grad(f, argnums=0, dyn_vars=self.dyn_vars) + res[0] = bm.vector_grad(f, argnums=0) return [tuple(res)], state def step(self, *args, **kwargs): @@ -416,7 +412,6 @@ def __init__( intg_type: str = None, wiener_type: str = None, state_delays: Dict[str, bm.AbstractDelay] = None, - dyn_vars: Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]] = None, ): super(MilsteinGradFree, self).__init__(f=f, g=g, @@ -425,8 +420,7 @@ def __init__( var_type=var_type, intg_type=intg_type, wiener_type=wiener_type, - state_delays=state_delays, - dyn_vars=dyn_vars) + state_delays=state_delays) self.set_integral(self.step) def step(self, *args, **kwargs): @@ -558,7 +552,6 @@ def __init__( var_type=var_type, intg_type=intg_type, wiener_type=wiener_type, - dyn_vars=dyn_vars, state_delays=state_delays) if self.intg_type == constants.STRA_SDE: @@ -626,7 +619,7 @@ def _build_integrator(self, f): if len(vars) != 1: raise errors.DiffEqError(constants.multi_vars_msg.format(cls=self.__class__.__name__, vars=str(vars), eq=str(f))) - value_and_grad = bm.vector_grad(f, argnums=0, dyn_vars=self.dyn_vars, return_value=True) + value_and_grad = bm.vector_grad(f, argnums=0, return_value=True) # integration function def integral(*args, **kwargs): diff --git a/brainpy/_src/integrators/sde/tests/test_normal.py b/brainpy/_src/integrators/sde/tests/test_normal.py index f4fc14cf1..5a15a9680 100644 --- a/brainpy/_src/integrators/sde/tests/test_normal.py +++ b/brainpy/_src/integrators/sde/tests/test_normal.py @@ -7,7 +7,7 @@ import matplotlib.pyplot as plt from brainpy._src.integrators.sde.normal import ExponentialEuler -block = False +show = False class TestExpEuler(unittest.TestCase): @@ -33,7 +33,9 @@ def lorenz_g(x, y, z, t, **kwargs): runner.run(100.) plt.plot(runner.mon.x.flatten(), runner.mon.y.flatten()) - plt.show(block=block) + if show: + plt.show() + plt.close() def test2(self): p = 0.1 @@ -41,8 +43,8 @@ def test2(self): def lorenz_g(x, y, z, t, **kwargs): return bp.math.asarray([p * x, p2 * x]), \ - bp.math.asarray([p * y, p2 * y]), \ - bp.math.asarray([p * z, p2 * z]) + bp.math.asarray([p * y, p2 * y]), \ + bp.math.asarray([p * z, p2 * z]) dx = lambda x, t, y, sigma=10: sigma * (y - x) dy = lambda y, t, x, z, rho=28: x * (rho - z) - y @@ -54,8 +56,8 @@ def lorenz_g(x, y, z, t, **kwargs): wiener_type=bp.integrators.VECTOR_WIENER, var_type=bp.integrators.POP_VAR, show_code=True) - runner = bp.integrators.IntegratorRunner(intg, monitors=['x', 'y', 'z'], - dt=0.001, inits=[1., 1., 0.], jit=False) + runner = bp.IntegratorRunner(intg, monitors=['x', 'y', 'z'], + dt=0.001, inits=[1., 1., 0.], jit=False) with self.assertRaises(ValueError): runner.run(100.) 
@@ -65,8 +67,8 @@ def test3(self): def lorenz_g(x, y, z, t, **kwargs): return bp.math.asarray([p * x, p2 * x]).T, \ - bp.math.asarray([p * y, p2 * y]).T, \ - bp.math.asarray([p * z, p2 * z]).T + bp.math.asarray([p * y, p2 * y]).T, \ + bp.math.asarray([p * z, p2 * z]).T dx = lambda x, t, y, sigma=10: sigma * (y - x) dy = lambda y, t, x, z, rho=28: x * (rho - z) - y @@ -78,15 +80,17 @@ def lorenz_g(x, y, z, t, **kwargs): wiener_type=bp.integrators.VECTOR_WIENER, var_type=bp.integrators.POP_VAR, show_code=True) - runner = bp.integrators.IntegratorRunner(intg, - monitors=['x', 'y', 'z'], - dt=0.001, - inits=[1., 1., 0.], - jit=True) + runner = bp.IntegratorRunner(intg, + monitors=['x', 'y', 'z'], + dt=0.001, + inits=[1., 1., 0.], + jit=True) runner.run(100.) plt.plot(runner.mon.x.flatten(), runner.mon.y.flatten()) - plt.show(block=block) + if show: + plt.show() + plt.close() class TestMilstein(unittest.TestCase): @@ -110,11 +114,14 @@ def test1(self): wiener_type=bp.integrators.SCALAR_WIENER, var_type=bp.integrators.POP_VAR, method='milstein') - runner = bp.integrators.IntegratorRunner(intg, - monitors=['x', 'y', 'z'], - dt=0.001, inits=[1., 1., 0.], - jit=True) + runner = bp.IntegratorRunner(intg, + monitors=['x', 'y', 'z'], + dt=0.001, inits=[1., 1., 0.], + jit=True) runner.run(100.) plt.plot(runner.mon.x.flatten(), runner.mon.y.flatten()) - plt.show(block=block) + if show: + plt.show() + plt.close() + diff --git a/brainpy/_src/math/object_transform/_tools.py b/brainpy/_src/math/object_transform/_tools.py index cd58c5433..4ffacc6c2 100644 --- a/brainpy/_src/math/object_transform/_tools.py +++ b/brainpy/_src/math/object_transform/_tools.py @@ -1,8 +1,54 @@ import warnings +from functools import wraps +from typing import Sequence + import jax -from brainpy._src.math.object_transform.variables import VariableStack + from brainpy._src.math.object_transform.naming import (cache_stack, get_stack_cache) +from brainpy._src.math.object_transform.variables import VariableStack + + +class Empty(object): + pass + + +empty = Empty() + + +def _partial_fun(fun, + args: tuple, + kwargs: dict, + static_argnums: Sequence[int] = (), + static_argnames: Sequence[str] = ()): + static_args, dyn_args = [], [] + for i, arg in enumerate(args): + if i in static_argnums: + static_args.append(arg) + else: + static_args.append(empty) + dyn_args.append(arg) + static_kwargs, dyn_kwargs = {}, {} + for k, arg in kwargs.items(): + if k in static_argnames: + static_kwargs[k] = arg + else: + dyn_kwargs[k] = arg + del args, kwargs, static_argnums, static_argnames + + @wraps(fun) + def new_fun(*dynargs, **dynkwargs): + args = [] + i = 0 + for arg in static_args: + if arg == empty: + args.append(dynargs[i]) + i += 1 + else: + args.append(arg) + return fun(*args, **static_kwargs, **dynkwargs) + + return new_fun, dyn_args, dyn_kwargs def dynvar_deprecation(dyn_vars=None): @@ -30,16 +76,22 @@ def abstract(x): return jax.api_util.shaped_abstractify(x) -def evaluate_dyn_vars(f, *args, **kwargs): +def evaluate_dyn_vars(f, + *args, + static_argnums: Sequence[int] = (), + static_argnames: Sequence[str] = (), + **kwargs): # TODO: better way for cache mechanism stack = get_stack_cache(f) if stack is None: + if len(static_argnums) or len(static_argnames): + f2, args, kwargs = _partial_fun(f, args, kwargs, static_argnums=static_argnums, static_argnames=static_argnames) + else: + f2, args, kwargs = f, args, kwargs with jax.ensure_compile_time_eval(): - args, kwargs = jax.tree_util.tree_map(abstract, (args, kwargs)) with VariableStack() as 
stack: - _ = jax.eval_shape(f, *args, **kwargs) + _ = jax.eval_shape(f2, *args, **kwargs) cache_stack(f, stack) # cache + del args, kwargs, f2 return stack - - diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 71f764bae..700a99cf0 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -440,11 +440,11 @@ def cond( >>> def true_f(_): a.value += 1 >>> def false_f(_): b.value -= 1 >>> - >>> bm.cond(True, true_f, false_f, dyn_vars=[a, b]) + >>> bm.cond(True, true_f, false_f) >>> a, b Variable([1., 1.], dtype=float32), Variable([1., 1.], dtype=float32) >>> - >>> bm.cond(False, true_f, false_f, dyn_vars=[a, b]) + >>> bm.cond(False, true_f, false_f) >>> a, b Variable([1., 1.], dtype=float32), Variable([0., 0.], dtype=float32) @@ -595,14 +595,10 @@ def ifelse( # format new codes if len(conditions) == 1: - if len(dyn_vars) > 0: - return cond(conditions[0], - branches[0], - branches[1], - operands, - dyn_vars) - else: - return lax.cond(conditions[0], branches[0], branches[1], operands) + return cond(conditions[0], + branches[0], + branches[1], + operands) else: code_scope = {'conditions': conditions, 'branches': branches} codes = ['def f(operands):', diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index 9f6d85646..e0b06aaad 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -13,9 +13,9 @@ import jax from brainpy import tools, check -from .naming import get_stack_cache, cache_stack -from ._tools import dynvar_deprecation, node_deprecation, evaluate_dyn_vars, abstract +from ._tools import dynvar_deprecation, node_deprecation, evaluate_dyn_vars, _partial_fun from .base import BrainPyObject, ObjectTransform +from .naming import get_stack_cache, cache_stack from .variables import Variable, VariableStack __all__ = [ @@ -23,6 +23,30 @@ ] +def _seq_of_int(static_argnums): + if static_argnums is None: + static_argnums = () + elif isinstance(static_argnums, int): + static_argnums = (static_argnums,) + elif isinstance(static_argnums, (tuple, list)): + pass + else: + raise TypeError('static_argnums must be None, int, or sequence of int.') + return static_argnums + + +def _seq_of_str(static_argnames): + if static_argnames is None: + static_argnames = () + elif isinstance(static_argnames, str): + static_argnames = (static_argnames,) + elif isinstance(static_argnames, (tuple, list)): + pass + else: + raise TypeError('static_argnums must be None, str, or sequence of str.') + return static_argnames + + class JITTransform(ObjectTransform): """Object-oriented JIT transformation in BrainPy.""" @@ -58,8 +82,8 @@ def __init__( # parameters self._backend = backend - self._static_argnums = static_argnums - self._static_argnames = static_argnames + self._static_argnums = _seq_of_int(static_argnums) + self._static_argnames = _seq_of_str(static_argnames) self._donate_argnums = donate_argnums self._device = device self._inline = inline @@ -82,7 +106,11 @@ def __call__(self, *args, **kwargs): return self.fun(*args, **kwargs) if self._transform is None: - self._dyn_vars = evaluate_dyn_vars(self.fun, *args, **kwargs) + self._dyn_vars = evaluate_dyn_vars(self.fun, + *args, + static_argnums=self._static_argnums, + static_argnames=self._static_argnames, + **kwargs) self._transform = jax.jit( self._transform_function, static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), @@ -183,6 
+211,7 @@ def jit( You can JIT any object in which all dynamical variables are defined as :py:class:`~.Variable`. + >>> import brainpy as bp >>> class Hello(bp.BrainPyObject): >>> def __init__(self): >>> super(Hello, self).__init__() @@ -270,15 +299,16 @@ def cls_jit( keep_unused: bool = False, abstracted_axes: Optional[Any] = None, ) -> Callable: - """Just-in-time compile a function and then the jitted function as a bound method for a class. + """Just-in-time compile a function and then the jitted function as the bound method for a class. Examples -------- This transformation can be put on any class function. For example, + >>> import brainpy as bp >>> import brainpy.math as bm - >>> + >>> >>> class SomeProgram(bp.BrainPyObject): >>> def __init__(self): >>> super(SomeProgram, self).__init__() @@ -290,8 +320,7 @@ def cls_jit( >>> a = bm.random.uniform(size=2) >>> a = a.at[0].set(1.) >>> self.b += a - >>> return self.b - >>> + >>> >>> program = SomeProgram() >>> program() @@ -334,17 +363,25 @@ def _make_jit_fun( keep_unused: bool = False, abstracted_axes: Optional[Any] = None, ): + static_argnums = _seq_of_int(static_argnums) + static_argnames = _seq_of_int(static_argnames) + @wraps(fun) def call_fun(self, *args, **kwargs): fun2 = partial(fun, self) if jax.config.jax_disable_jit: return fun2(*args, **kwargs) - cache = get_stack_cache(fun2) # TODO: better cache mechanism + + hash_v = hash(fun) + hash(self) + cache = get_stack_cache(hash_v) # TODO: better cache mechanism if cache is None: with jax.ensure_compile_time_eval(): - args_, kwargs_ = jax.tree_util.tree_map(abstract, (args, kwargs)) + if len(static_argnums) or len(static_argnames): + fun3, args_, kwargs_ = _partial_fun(fun2, args, kwargs, static_argnums, static_argnames) + else: + args_, kwargs_, fun3 = args, kwargs, fun2 with VariableStack() as stack: - _ = jax.eval_shape(fun2, *args_, **kwargs_) + _ = jax.eval_shape(fun3, *args_, **kwargs_) del args_, kwargs_ _transform = jax.jit( _make_transform(fun2, stack), @@ -355,7 +392,8 @@ def call_fun(self, *args, **kwargs): keep_unused=keep_unused, abstracted_axes=abstracted_axes ) - cache_stack(fun2, (stack, _transform)) # cache + cache_stack(hash_v, (stack, _transform)) # cache "variable stack" and "transform function" + else: stack, _transform = cache del cache @@ -368,6 +406,8 @@ def call_fun(self, *args, **kwargs): def _make_transform(fun, stack): + + @wraps(fun) def _transform_function(variable_data: dict, *args, **kwargs): for key, v in stack.items(): v._value = variable_data[key] diff --git a/brainpy/_src/math/object_transform/tests/test_autograd.py b/brainpy/_src/math/object_transform/tests/test_autograd.py index 87fc1b913..66499954a 100644 --- a/brainpy/_src/math/object_transform/tests/test_autograd.py +++ b/brainpy/_src/math/object_transform/tests/test_autograd.py @@ -100,17 +100,18 @@ def __call__(self): bm.random.seed(0) t = Test() - f_grad = bm.grad(t, grad_vars=t.vars()) + f_grad = bm.grad(t, grad_vars={'a': t.a, 'b': t.b, 'c': t.c}) grads = f_grad() - for g in grads.values(): assert (g == 1.).all() + for g in grads.values(): + assert (g == 1.).all() t = Test() - f_grad = bm.grad(t, grad_vars=[t.a, t.b], dyn_vars=t.vars()) + f_grad = bm.grad(t, grad_vars=[t.a, t.b]) grads = f_grad() for g in grads: assert (g == 1.).all() t = Test() - f_grad = bm.grad(t, grad_vars=t.a, dyn_vars=t.vars()) + f_grad = bm.grad(t, grad_vars=t.a) grads = f_grad() assert (grads == 1.).all() @@ -127,14 +128,14 @@ def __call__(self): bm.random.seed(0) t = Test() - f_grad = bm.grad(t, grad_vars=[t.a, 
t.b], dyn_vars=t.vars(), has_aux=True) + f_grad = bm.grad(t, grad_vars=[t.a, t.b], has_aux=True) grads, aux = f_grad() for g in grads: assert (g == 1.).all() assert aux[0] == bm.sin(100) assert aux[1] == bm.exp(0.1) t = Test() - f_grad = bm.grad(t, grad_vars=t.a, dyn_vars=t.vars(), has_aux=True) + f_grad = bm.grad(t, grad_vars=t.a, has_aux=True) grads, aux = f_grad() assert (grads == 1.).all() assert aux[0] == bm.sin(100) @@ -153,13 +154,13 @@ def __call__(self): bm.random.seed(0) t = Test() - f_grad = bm.grad(t, grad_vars=[t.a, t.b], dyn_vars=t.vars(), return_value=True) + f_grad = bm.grad(t, grad_vars=[t.a, t.b], return_value=True) grads, returns = f_grad() for g in grads: assert (g == 1.).all() assert returns == t() t = Test() - f_grad = bm.grad(t, grad_vars=t.a, dyn_vars=t.vars(), return_value=True) + f_grad = bm.grad(t, grad_vars=t.a, return_value=True) grads, returns = f_grad() assert (grads == 1.).all() assert returns == t() @@ -177,7 +178,7 @@ def __call__(self): bm.random.seed(0) t = Test() - f_grad = bm.grad(t, grad_vars=[t.a, t.b], dyn_vars=t.vars(), + f_grad = bm.grad(t, grad_vars=[t.a, t.b], has_aux=True, return_value=True) grads, returns, aux = f_grad() for g in grads: assert (g == 1.).all() @@ -186,7 +187,7 @@ def __call__(self): assert aux[1] == bm.exp(0.1) t = Test() - f_grad = bm.grad(t, grad_vars=t.a, dyn_vars=t.vars(), + f_grad = bm.grad(t, grad_vars=t.a, has_aux=True, return_value=True) grads, returns, aux = f_grad() assert (grads == 1.).all() @@ -221,12 +222,12 @@ def __call__(self, d): assert (arg_grads[0] == 2.).all() t = Test() - f_grad = bm.grad(t, dyn_vars=t.vars(), argnums=0) + f_grad = bm.grad(t, argnums=0) arg_grads = f_grad(bm.random.random(10)) assert (arg_grads == 2.).all() t = Test() - f_grad = bm.grad(t, dyn_vars=t.vars(), argnums=[0]) + f_grad = bm.grad(t, argnums=[0]) arg_grads = f_grad(bm.random.random(10)) assert (arg_grads[0] == 2.).all() @@ -260,14 +261,14 @@ def __call__(self, d): assert aux[1] == bm.exp(0.1) t = Test() - f_grad = bm.grad(t, dyn_vars=t.vars(), argnums=0, has_aux=True) + f_grad = bm.grad(t, argnums=0, has_aux=True) arg_grads, aux = f_grad(bm.random.random(10)) assert (arg_grads == 2.).all() assert aux[0] == bm.sin(100) assert aux[1] == bm.exp(0.1) t = Test() - f_grad = bm.grad(t, dyn_vars=t.vars(), argnums=[0], has_aux=True) + f_grad = bm.grad(t, argnums=[0], has_aux=True) arg_grads, aux = f_grad(bm.random.random(10)) assert (arg_grads[0] == 2.).all() assert aux[0] == bm.sin(100) @@ -304,14 +305,14 @@ def __call__(self, d): assert loss == t(d) t = Test() - f_grad = bm.grad(t, dyn_vars=t.vars(), argnums=0, return_value=True) + f_grad = bm.grad(t, argnums=0, return_value=True) d = bm.random.random(10) arg_grads, loss = f_grad(d) assert (arg_grads == 2.).all() assert loss == t(d) t = Test() - f_grad = bm.grad(t, dyn_vars=t.vars(), argnums=[0], return_value=True) + f_grad = bm.grad(t, argnums=[0], return_value=True) d = bm.random.random(10) arg_grads, loss = f_grad(d) assert (arg_grads[0] == 2.).all() @@ -351,7 +352,7 @@ def __call__(self, d): assert loss == t(d)[0] t = Test() - f_grad = bm.grad(t, dyn_vars=t.vars(), argnums=0, has_aux=True, return_value=True) + f_grad = bm.grad(t, argnums=0, has_aux=True, return_value=True) d = bm.random.random(10) arg_grads, loss, aux = f_grad(d) assert (arg_grads == 2.).all() @@ -360,7 +361,7 @@ def __call__(self, d): assert loss == t(d)[0] t = Test() - f_grad = bm.grad(t, dyn_vars=t.vars(), argnums=[0], has_aux=True, return_value=True) + f_grad = bm.grad(t, argnums=[0], has_aux=True, 
return_value=True) d = bm.random.random(10) arg_grads, loss, aux = f_grad(d) assert (arg_grads[0] == 2.).all() @@ -383,12 +384,12 @@ def f(x): f2 = lambda x: x ** 3 self.assertEqual(_jacfwd(f)(4.), _jacfwd(f2)(4.)) - self.assertEqual(bm.jit(_jacfwd(f))(4.), _jacfwd(f2)(4.)) - self.assertEqual(bm.jit(_jacfwd(bm.jit(f)))(4.), _jacfwd(f2)(4.)) + self.assertEqual(jax.jit(_jacfwd(f))(4.), _jacfwd(f2)(4.)) + self.assertEqual(jax.jit(_jacfwd(jax.jit(f)))(4.), _jacfwd(f2)(4.)) self.assertEqual(_jacfwd(f)(bm.asarray(4.)), _jacfwd(f2)(bm.asarray(4.))) - self.assertEqual(bm.jit(_jacfwd(f))(bm.asarray(4.)), _jacfwd(f2)(bm.asarray(4.))) - self.assertEqual(bm.jit(_jacfwd(bm.jit(f)))(bm.asarray(4.)), _jacfwd(f2)(bm.asarray(4.))) + self.assertEqual(jax.jit(_jacfwd(f))(bm.asarray(4.)), _jacfwd(f2)(bm.asarray(4.))) + self.assertEqual(jax.jit(_jacfwd(jax.jit(f)))(bm.asarray(4.)), _jacfwd(f2)(bm.asarray(4.))) def f(x): jac, aux = _jacfwd(lambda x: (x ** 3, [x ** 3]), has_aux=True)(x) @@ -397,12 +398,12 @@ def f(x): f2 = lambda x: x ** 3 * bm.sin(x) self.assertEqual(_jacfwd(f)(4.), _jacfwd(f2)(4.)) - self.assertEqual(bm.jit(_jacfwd(f))(4.), _jacfwd(f2)(4.)) - self.assertEqual(bm.jit(_jacfwd(bm.jit(f)))(4.), _jacfwd(f2)(4.)) + self.assertEqual(jax.jit(_jacfwd(f))(4.), _jacfwd(f2)(4.)) + self.assertEqual(jax.jit(_jacfwd(jax.jit(f)))(4.), _jacfwd(f2)(4.)) self.assertEqual(_jacfwd(f)(bm.asarray(4.)), _jacfwd(f2)(bm.asarray(4.))) - self.assertEqual(bm.jit(_jacfwd(f))(bm.asarray(4.)), _jacfwd(f2)(bm.asarray(4.))) - self.assertEqual(bm.jit(_jacfwd(bm.jit(f)))(bm.asarray(4.)), _jacfwd(f2)(bm.asarray(4.))) + self.assertEqual(jax.jit(_jacfwd(f))(bm.asarray(4.)), _jacfwd(f2)(bm.asarray(4.))) + self.assertEqual(jax.jit(_jacfwd(jax.jit(f)))(bm.asarray(4.)), _jacfwd(f2)(bm.asarray(4.))) def test_jacrev1(self): def f1(x, y): diff --git a/brainpy/_src/math/object_transform/tests/test_controls.py b/brainpy/_src/math/object_transform/tests/test_controls.py index bd6c09f90..4dd12d4d7 100644 --- a/brainpy/_src/math/object_transform/tests/test_controls.py +++ b/brainpy/_src/math/object_transform/tests/test_controls.py @@ -11,7 +11,7 @@ import brainpy.math as bm -class TestLoop(jtu.JaxTestCase): +class TestLoop(parameterized.TestCase): def test_make_loop(self): def make_node(v1, v2): def update(x): diff --git a/brainpy/_src/math/object_transform/tests/test_jit.py b/brainpy/_src/math/object_transform/tests/test_jit.py index f8691c80d..2467f2d1a 100644 --- a/brainpy/_src/math/object_transform/tests/test_jit.py +++ b/brainpy/_src/math/object_transform/tests/test_jit.py @@ -42,6 +42,28 @@ def __call__(self, *args, **kwargs): self.assertTrue(bm.array_equal(b_out, program.b)) print(b_out) + def test_jit_with_static(self): + a = bm.Variable(bm.ones(2)) + + @bm.jit(static_argnums=1) + def f(b, c): + a.value *= b + a.value /= c + + f(1., 2.) + self.assertTrue(bm.allclose(a.value, 0.5)) + + @bm.jit(static_argnames=['c']) + def f2(b, c): + a.value *= b + a.value /= c + + f2(2., c=1.) + self.assertTrue(bm.allclose(a.value, 1.)) + + +class TestClsJIT(bp.testing.UnitTestCase): + def test_class_jit1(self): class SomeProgram(bp.BrainPyObject): def __init__(self): @@ -92,5 +114,30 @@ def update(self, x): program.update(1.) self.assertTrue(bm.allclose(new_b + 1., program.b)) + def test_cls_jit_with_static(self): + class MyObj: + def __init__(self): + self.a = bm.Variable(bm.ones(2)) + + @bm.cls_jit(static_argnums=1) + def f(self, b, c): + self.a.value *= b + self.a.value /= c + + obj = MyObj() + obj.f(1., 2.) 
+ self.assertTrue(bm.allclose(obj.a.value, 0.5)) + + class MyObj2: + def __init__(self): + self.a = bm.Variable(bm.ones(2)) + + @bm.cls_jit(static_argnames=['c']) + def f(self, b, c): + self.a.value *= b + self.a.value /= c + obj = MyObj2() + obj.f(1., c=2.) + self.assertTrue(bm.allclose(obj.a.value, 0.5)) diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py index 1a8b9a06e..001d2d1c7 100644 --- a/brainpy/_src/math/object_transform/variables.py +++ b/brainpy/_src/math/object_transform/variables.py @@ -156,7 +156,7 @@ def __init__( f'but the batch axis is set to be {batch_axis}.') # ready to trace the variable - self._ready_to_trace = _ready_to_trace + self._ready_to_trace = _ready_to_trace and len(var_stack_list) == 0 @property def nobatch_shape(self) -> TupleType[int, ...]: diff --git a/brainpy/_src/measure/correlation.py b/brainpy/_src/measure/correlation.py index 689c8c0c3..9e3dd9d0a 100644 --- a/brainpy/_src/measure/correlation.py +++ b/brainpy/_src/measure/correlation.py @@ -89,7 +89,7 @@ def _f(i, j): lambda _: 0., lambda _: jnp.sum(states[i] * states[j]) / sqrt_ij, None) - res = bm.for_loop(_f, dyn_vars=[], operands=indices) + res = bm.for_loop(_f, operands=indices) elif method == 'vmap': @vmap @@ -178,7 +178,6 @@ def voltage_fluctuation(potentials, numpy=True, method='loop'): if method == 'loop': _var = lambda aa: bm.for_loop(lambda signal: jnp.mean(signal * signal) - jnp.mean(signal) ** 2, - dyn_vars=(), operands=jnp.moveaxis(aa, 0, 1)) elif method == 'vmap': diff --git a/brainpy/_src/train/back_propagation.py b/brainpy/_src/train/back_propagation.py index 083169806..3d4b6f4cb 100644 --- a/brainpy/_src/train/back_propagation.py +++ b/brainpy/_src/train/back_propagation.py @@ -436,11 +436,7 @@ def _get_f_loss(self, shared_args=None, jit=True) -> Callable: if shared_args_str not in self._f_loss_compiled: self._f_loss_compiled[shared_args_str] = partial(self._step_func_loss, shared_args) if self.jit[c.LOSS_PHASE] and jit: - dyn_vars = self.target.vars() - dyn_vars.update(self._dyn_vars) - dyn_vars.update(self.vars(level=0)) - self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str], - dyn_vars=dyn_vars.unique()) + self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str]) return self._f_loss_compiled[shared_args_str] def _get_f_grad(self, shared_args=None) -> Callable: @@ -450,10 +446,8 @@ def _get_f_grad(self, shared_args=None) -> Callable: _f_loss_internal = self._get_f_loss(shared_args, jit=False) dyn_vars = self.target.vars() dyn_vars.update(self._dyn_vars) - dyn_vars = dyn_vars.unique() - tran_vars = dyn_vars.subset(bm.TrainVar) + tran_vars = dyn_vars.subset(bm.TrainVar).unique() grad_f = bm.grad(_f_loss_internal, - dyn_vars=dyn_vars, grad_vars=tran_vars, return_value=True, has_aux=self.loss_has_aux) @@ -478,8 +472,7 @@ def _get_f_train(self, shared_args=None) -> Callable: dyn_vars.update(self._dyn_vars) dyn_vars.update(self.vars(level=0)) dyn_vars = dyn_vars.unique() - self._f_fit_compiled[shared_args_str] = bm.jit(self._f_fit_compiled[shared_args_str], - dyn_vars=dyn_vars) + self._f_fit_compiled[shared_args_str] = bm.jit(self._f_fit_compiled[shared_args_str]) return self._f_fit_compiled[shared_args_str] def _step_func_loss(self, shared_args, inputs, targets): @@ -602,11 +595,7 @@ def _get_f_predict(self, shared_args: Dict = None, jit: bool = True): self._f_predict_compiled[shared_args_str] = partial(self._step_func_predict, shared_args) if 
self.jit[c.PREDICT_PHASE] and jit: - dyn_vars = self.target.vars() - dyn_vars.update(self._dyn_vars) - dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView) - self._f_predict_compiled[shared_args_str] = bm.jit(self._f_predict_compiled[shared_args_str], - dyn_vars=dyn_vars.unique()) + self._f_predict_compiled[shared_args_str] = bm.jit(self._f_predict_compiled[shared_args_str]) return self._f_predict_compiled[shared_args_str] def predict( diff --git a/brainpy/_src/train/offline.py b/brainpy/_src/train/offline.py index 69d534e4a..1433a5e22 100644 --- a/brainpy/_src/train/offline.py +++ b/brainpy/_src/train/offline.py @@ -207,7 +207,7 @@ def _get_f_train(self, shared_args: Dict = None) -> Callable: self._f_fit_compiled[shared_kwargs_str] = ( self._fun_train if self.jit['fit'] else - bm.jit(self._fun_train, dyn_vars=self.vars().unique()) + bm.jit(self._fun_train) ) return self._f_fit_compiled[shared_kwargs_str] diff --git a/brainpy/_src/train/online.py b/brainpy/_src/train/online.py index 9b21b25fe..0438ab001 100644 --- a/brainpy/_src/train/online.py +++ b/brainpy/_src/train/online.py @@ -239,14 +239,11 @@ def _get_fit_func(self, shared_args: Dict = None): if shared_args is None: shared_args = dict() shared_kwargs_str = serialize_kwargs(shared_args) if shared_kwargs_str not in self._f_fit_compiled: - dyn_vars = self.vars().unique() - dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView) - + @bm.jit def run_func(all_inputs): - with jax.disable_jit(not self.jit['fit']): - return bm.for_loop(partial(self._step_func_fit, shared_args), - all_inputs, - dyn_vars=dyn_vars) + return bm.for_loop(partial(self._step_func_fit, shared_args), + all_inputs, + jit=self.jit['fit']) self._f_fit_compiled[shared_kwargs_str] = run_func return self._f_fit_compiled[shared_kwargs_str] diff --git a/brainpy/math/object_base.py b/brainpy/math/object_base.py index 1f19459c0..34561e011 100644 --- a/brainpy/math/object_base.py +++ b/brainpy/math/object_base.py @@ -2,6 +2,7 @@ from brainpy._src.math.object_transform.base import (BrainPyObject as BrainPyObject, FunAsObject as FunAsObject) +from brainpy._src.math.object_transform.function import (Partial as Partial) from brainpy._src.math.object_transform.base import (NodeList as NodeList, NodeDict as NodeDict,) from brainpy._src.math.object_transform.variables import (Variable as Variable, diff --git a/brainpy/math/object_transform.py b/brainpy/math/object_transform.py index 9bcf8c763..d281ec740 100644 --- a/brainpy/math/object_transform.py +++ b/brainpy/math/object_transform.py @@ -22,11 +22,11 @@ from brainpy._src.math.object_transform.jit import ( jit as jit, - cls_jit, + cls_jit as cls_jit, ) from brainpy._src.math.object_transform.function import ( to_object as to_object, function as function, -) \ No newline at end of file +) diff --git a/examples/dynamics_training/Song_2016_EI_RNN.py b/examples/dynamics_training/Song_2016_EI_RNN.py index 2eaba0d2b..bd5180728 100644 --- a/examples/dynamics_training/Song_2016_EI_RNN.py +++ b/examples/dynamics_training/Song_2016_EI_RNN.py @@ -107,14 +107,13 @@ def loss(self, xs, ys): # gradient function grad_f = bm.grad(net.loss, - child_objs=net, grad_vars=net.train_vars().unique(), return_value=True, has_aux=True) # training function -@bm.jit(child_objs=(net, opt)) +@bm.jit def train(xs, ys): grads, loss, acc = grad_f(xs, ys) opt.update(grads) diff --git a/examples/dynamics_training/Sussillo_Abbott_2009_FORCE_Learning.py b/examples/dynamics_training/Sussillo_Abbott_2009_FORCE_Learning.py index 310c144d4..047f11141 100644 --- 
a/examples/dynamics_training/Sussillo_Abbott_2009_FORCE_Learning.py +++ b/examples/dynamics_training/Sussillo_Abbott_2009_FORCE_Learning.py @@ -78,14 +78,14 @@ def rls(self, target): self.w_ro += dw def simulate(self, xs): - return bm.for_loop(self.update, dyn_vars=self.vars(), operands=xs) + return bm.for_loop(self.update, operands=xs) def train(self, xs, targets): def _f(x, target): r, o = self.update(x) self.rls(target) return r, o - return bm.for_loop(_f, dyn_vars=self.vars(), operands=[xs, targets]) + return bm.for_loop(_f, operands=[xs, targets]) # %% diff --git a/examples/dynamics_training/integrator_rnn.py b/examples/dynamics_training/integrator_rnn.py index 706e51bd6..9e3c318ff 100644 --- a/examples/dynamics_training/integrator_rnn.py +++ b/examples/dynamics_training/integrator_rnn.py @@ -10,7 +10,7 @@ num_batch = 128 -@bm.jit(static_argnames=['batch_size'], dyn_vars=bm.random.DEFAULT) +@bm.jit(static_argnames=['batch_size']) def build_inputs_and_targets(mean=0.025, scale=0.01, batch_size=10): # Create the white noise input sample = bm.random.normal(size=(batch_size, 1, 1)) diff --git a/examples/dynamics_training/reservoir-mnist.py b/examples/dynamics_training/reservoir-mnist.py index a868b8bf8..6216325ff 100644 --- a/examples/dynamics_training/reservoir-mnist.py +++ b/examples/dynamics_training/reservoir-mnist.py @@ -46,8 +46,7 @@ def offline_train(num_hidden=2000, num_in=28, num_out=10): ) preds = bm.for_loop(lambda x: jnp.argmax(esn({}, x), axis=-1), - x_train, - child_objs=esn) + x_train) accuracy = jnp.mean(preds == jnp.repeat(traindata.targets, x_train.shape[1])) print(accuracy) @@ -73,7 +72,7 @@ def force_online_train(num_hidden=2000, num_in=28, num_out=10, train_stage='fina rls = bp.algorithms.RLS() rls.register_target(num_hidden) - @bm.jit(child_objs=(reservoir, readout, rls)) + @bm.jit def train_step(xs, y): reservoir.reset_state(xs.shape[0]) if train_stage == 'final_step': @@ -91,7 +90,7 @@ def train_step(xs, y): else: raise ValueError - @bm.jit(child_objs=(reservoir, readout)) + @bm.jit def predict(xs): reservoir.reset_state(xs.shape[0]) for x in xs.transpose(1, 0, 2): diff --git a/examples/training_ann_models/mnist_ResNet.py b/examples/training_ann_models/mnist_ResNet.py index 690e16aea..9a74ddbb9 100644 --- a/examples/training_ann_models/mnist_ResNet.py +++ b/examples/training_ann_models/mnist_ResNet.py @@ -212,7 +212,6 @@ def main(): net = ResNet18(num_classes=10) # loss function - @bm.to_object(child_objs=net) def loss_fun(X, Y, fit=True): s = {'fit': fit} predictions = net(s, X) @@ -227,13 +226,12 @@ def loss_fun(X, Y, fit=True): train_vars=net.train_vars().unique()) @bm.jit - @bm.to_object(child_objs=(grad_fun, optimizer)) def train_fun(X, Y): grads, l, n = grad_fun(X, Y) optimizer.update(grads) return l, n - predict_loss_fun = bm.jit(partial(loss_fun, fit=False), child_objs=loss_fun) + predict_loss_fun = bm.jit(partial(loss_fun, fit=False)) os.makedirs(out_dir, exist_ok=True) with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt: diff --git a/examples/training_snn_models/spikebased_bp_for_cifar10.py b/examples/training_snn_models/spikebased_bp_for_cifar10.py index 48c93b871..91e98abb1 100644 --- a/examples/training_snn_models/spikebased_bp_for_cifar10.py +++ b/examples/training_snn_models/spikebased_bp_for_cifar10.py @@ -239,7 +239,7 @@ def main(): bm.random.seed(1234) net = ResNet11() - @bm.jit(child_objs=net, dyn_vars=bm.random.DEFAULT) + @bm.jit def loss_fun(x, y, fit=True): bp.share.save(fit=fit) yy = bm.one_hot(y, 10, 
dtype=bm.float_) @@ -262,7 +262,7 @@ def loss_fun(x, y, fit=True): train_vars=net.train_vars().unique(), weight_decay=5e-4) - @bm.jit(child_objs=(optimizer, grad_fun)) + @bm.jit def train_fun(x, y): grads, l, n = grad_fun(x, y) optimizer.update(grads)
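
The hunks above consistently remove the explicit `dyn_vars=` / `child_objs=` bookkeeping from BrainPy's object-oriented transforms (`bm.for_loop`, `bm.cond`, `bm.grad`, `bm.vector_grad`, `bm.jit`): variables touched inside the traced function are now collected automatically. Below is a minimal usage sketch of the new calling convention (not part of the patch), assuming a BrainPy build that includes these changes; the names `a`, `b`, and `accumulate` are illustrative only.

import brainpy.math as bm

a = bm.Variable(bm.ones(2))
b = bm.Variable(bm.ones(2))

def true_f(_):
  a.value += 1

def false_f(_):
  b.value -= 1

# before this patch: bm.cond(True, true_f, false_f, None, dyn_vars=[a, b])
bm.cond(True, true_f, false_f, None)

def accumulate(x):
  # `a` is discovered while tracing, so no dyn_vars argument is needed
  a.value += x
  return a.value

# before this patch: bm.for_loop(accumulate, xs, dyn_vars=[a])
outs = bm.for_loop(accumulate, bm.arange(4.))

# before this patch: bm.grad(loss, grad_vars=a, dyn_vars=[a, b])
f_grad = bm.grad(lambda: (a * b).sum(), grad_vars=a)
grads = f_grad()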
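
The additions in `_tools.py` and `jit.py` (`_partial_fun`, `_seq_of_int`, `_seq_of_str`) make `bm.jit` and `bm.cls_jit` support `static_argnums` / `static_argnames`: static arguments are stripped out while the variable stack is evaluated, then forwarded (shifted by one for the injected variable-data argument) to `jax.jit`. The sketch below mirrors the tests added in `test_jit.py`; it assumes the patched `brainpy.math` and uses hypothetical names (`scale`, `MyObj`).

import brainpy.math as bm

v = bm.Variable(bm.ones(2))

@bm.jit(static_argnums=1)        # the second positional argument is static
def scale(b, c):
  v.value *= b
  v.value /= c

scale(1., 2.)                    # v -> [0.5, 0.5]

@bm.jit(static_argnames=['c'])   # keyword argument `c` is static
def scale2(b, c):
  v.value *= b
  v.value /= c

scale2(2., c=1.)                 # v -> [1.0, 1.0]

class MyObj:
  def __init__(self):
    self.a = bm.Variable(bm.ones(2))

  @bm.cls_jit(static_argnums=1)  # indices count the arguments after `self`
  def f(self, b, c):
    self.a.value *= b
    self.a.value /= c

obj = MyObj()
obj.f(1., 2.)                    # obj.a -> [0.5, 0.5]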