Skip to content

Commit

Permalink
[oo transform] enable new style of jit transformation to support static_argnums and static_argnames (#360)
Browse files Browse the repository at this point in the history

[oo transform] enable new style of jit transformation to support `static_argnums` and `static_argnames`
  • Loading branch information
ztqakita authored Apr 11, 2023
2 parents 5874e8d + ba72bba commit 3d63531
Show file tree
Hide file tree
Showing 37 changed files with 303 additions and 197 deletions.
4 changes: 1 addition & 3 deletions brainpy/_src/analysis/highdim/slow_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,6 @@ def f_loss():

grad_f = bm.grad(f_loss, grad_vars=fixed_points, return_value=True)
optimizer.register_train_vars(fixed_points if isinstance(fixed_points, dict) else {'a': fixed_points})
dyn_vars = optimizer.vars() + (fixed_points if isinstance(fixed_points, dict) else {'a': fixed_points})
dyn_vars = dyn_vars.unique()

def train(idx):
gradients, loss = grad_f()
Expand All @@ -353,7 +351,7 @@ def train(idx):
return loss

def batch_train(start_i, n_batch):
return bm.for_loop(train, bm.arange(start_i, start_i + n_batch), dyn_vars=dyn_vars)
return bm.for_loop(train, bm.arange(start_i, start_i + n_batch))

# Run the optimization
if self.verbose:
Expand Down
53 changes: 27 additions & 26 deletions brainpy/_src/analysis/lowdim/lowdim_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import partial

import numpy as np
import jax
from jax import numpy as jnp
from jax import vmap
from jax.scipy.optimize import minimize
Expand Down Expand Up @@ -274,21 +275,21 @@ def F_fx(self):
f = partial(f, **(self.pars_update + self.fixed_vars))
f = utils.f_without_jaxarray_return(f)
f = utils.remove_return_shape(f)
self.analyzed_results[C.F_fx] = bm.jit(f, device=self.jit_device)
self.analyzed_results[C.F_fx] = jax.jit(f, device=self.jit_device)
return self.analyzed_results[C.F_fx]

@property
def F_vmap_fx(self):
if C.F_vmap_fx not in self.analyzed_results:
self.analyzed_results[C.F_vmap_fx] = bm.jit(vmap(self.F_fx), device=self.jit_device)
self.analyzed_results[C.F_vmap_fx] = jax.jit(vmap(self.F_fx), device=self.jit_device)
return self.analyzed_results[C.F_vmap_fx]

@property
def F_dfxdx(self):
"""The function to evaluate :math:`\frac{df_x(*\mathrm{vars}, *\mathrm{pars})}{dx}`."""
if C.F_dfxdx not in self.analyzed_results:
dfx = bm.vector_grad(self.F_fx, argnums=0)
self.analyzed_results[C.F_dfxdx] = bm.jit(dfx, device=self.jit_device)
self.analyzed_results[C.F_dfxdx] = jax.jit(dfx, device=self.jit_device)
return self.analyzed_results[C.F_dfxdx]

@property
Expand All @@ -307,7 +308,7 @@ def F_vmap_fp_aux(self):
# ---
# "X": a two-dimensional matrix: (num_batch, num_var)
# "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
self.analyzed_results[C.F_vmap_fp_aux] = bm.jit(vmap(self.F_fixed_point_aux))
self.analyzed_results[C.F_vmap_fp_aux] = jax.jit(vmap(self.F_fixed_point_aux))
return self.analyzed_results[C.F_vmap_fp_aux]

@property
Expand All @@ -326,7 +327,7 @@ def F_vmap_fp_opt(self):
# ---
# "X": a two-dimensional matrix: (num_batch, num_var)
# "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
self.analyzed_results[C.F_vmap_fp_opt] = bm.jit(vmap(self.F_fixed_point_opt))
self.analyzed_results[C.F_vmap_fp_opt] = jax.jit(vmap(self.F_fixed_point_opt))
return self.analyzed_results[C.F_vmap_fp_opt]

def _get_fixed_points(self, candidates, *args, num_seg=None, tol_aux=1e-7, loss_screen=None):
Expand Down Expand Up @@ -519,31 +520,31 @@ def F_y_by_x_in_fy(self):
@property
def F_vmap_fy(self):
if C.F_vmap_fy not in self.analyzed_results:
self.analyzed_results[C.F_vmap_fy] = bm.jit(vmap(self.F_fy), device=self.jit_device)
self.analyzed_results[C.F_vmap_fy] = jax.jit(vmap(self.F_fy), device=self.jit_device)
return self.analyzed_results[C.F_vmap_fy]

@property
def F_dfxdy(self):
"""The function to evaluate :math:`\frac{df_x (*\mathrm{vars}, *\mathrm{pars})}{dy}`."""
if C.F_dfxdy not in self.analyzed_results:
dfxdy = bm.vector_grad(self.F_fx, argnums=1)
self.analyzed_results[C.F_dfxdy] = bm.jit(dfxdy, device=self.jit_device)
self.analyzed_results[C.F_dfxdy] = jax.jit(dfxdy, device=self.jit_device)
return self.analyzed_results[C.F_dfxdy]

@property
def F_dfydx(self):
"""The function to evaluate :math:`\frac{df_y (*\mathrm{vars}, *\mathrm{pars})}{dx}`."""
if C.F_dfydx not in self.analyzed_results:
dfydx = bm.vector_grad(self.F_fy, argnums=0)
self.analyzed_results[C.F_dfydx] = bm.jit(dfydx, device=self.jit_device)
self.analyzed_results[C.F_dfydx] = jax.jit(dfydx, device=self.jit_device)
return self.analyzed_results[C.F_dfydx]

@property
def F_dfydy(self):
"""The function to evaluate :math:`\frac{df_y (*\mathrm{vars}, *\mathrm{pars})}{dy}`."""
if C.F_dfydy not in self.analyzed_results:
dfydy = bm.vector_grad(self.F_fy, argnums=1)
self.analyzed_results[C.F_dfydy] = bm.jit(dfydy, device=self.jit_device)
self.analyzed_results[C.F_dfydy] = jax.jit(dfydy, device=self.jit_device)
return self.analyzed_results[C.F_dfydy]

@property
Expand All @@ -556,7 +557,7 @@ def f_jacobian(*var_and_pars):

def call(*var_and_pars):
var_and_pars = tuple((vp.value if isinstance(vp, bm.Array) else vp) for vp in var_and_pars)
return jnp.array(bm.jit(f_jacobian, device=self.jit_device)(*var_and_pars))
return jnp.array(jax.jit(f_jacobian, device=self.jit_device)(*var_and_pars))

self.analyzed_results[C.F_jacobian] = call
return self.analyzed_results[C.F_jacobian]
Expand Down Expand Up @@ -681,7 +682,7 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux

if self.F_x_by_y_in_fx is not None:
utils.output("I am evaluating fx-nullcline by F_x_by_y_in_fx ...")
vmap_f = bm.jit(vmap(self.F_x_by_y_in_fx), device=self.jit_device)
vmap_f = jax.jit(vmap(self.F_x_by_y_in_fx), device=self.jit_device)
for j, pars in enumerate(par_seg):
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
mesh_values = jnp.meshgrid(*((ys,) + pars))
Expand All @@ -697,7 +698,7 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux

elif self.F_y_by_x_in_fx is not None:
utils.output("I am evaluating fx-nullcline by F_y_by_x_in_fx ...")
vmap_f = bm.jit(vmap(self.F_y_by_x_in_fx), device=self.jit_device)
vmap_f = jax.jit(vmap(self.F_y_by_x_in_fx), device=self.jit_device)
for j, pars in enumerate(par_seg):
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
mesh_values = jnp.meshgrid(*((xs,) + pars))
Expand All @@ -715,9 +716,9 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
utils.output("I am evaluating fx-nullcline by optimization ...")
# auxiliary functions
f2 = lambda y, x, *pars: self.F_fx(x, y, *pars)
vmap_f2 = bm.jit(vmap(f2), device=self.jit_device)
vmap_brentq_f2 = bm.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
vmap_brentq_f1 = bm.jit(vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device)
vmap_f2 = jax.jit(vmap(f2), device=self.jit_device)
vmap_brentq_f2 = jax.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
vmap_brentq_f1 = jax.jit(vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device)

# num segments
for _j, Ps in enumerate(par_seg):
Expand Down Expand Up @@ -774,7 +775,7 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux

if self.F_x_by_y_in_fy is not None:
utils.output("I am evaluating fy-nullcline by F_x_by_y_in_fy ...")
vmap_f = bm.jit(vmap(self.F_x_by_y_in_fy), device=self.jit_device)
vmap_f = jax.jit(vmap(self.F_x_by_y_in_fy), device=self.jit_device)
for j, pars in enumerate(par_seg):
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
mesh_values = jnp.meshgrid(*((ys,) + pars))
Expand All @@ -790,7 +791,7 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux

elif self.F_y_by_x_in_fy is not None:
utils.output("I am evaluating fy-nullcline by F_y_by_x_in_fy ...")
vmap_f = bm.jit(vmap(self.F_y_by_x_in_fy), device=self.jit_device)
vmap_f = jax.jit(vmap(self.F_y_by_x_in_fy), device=self.jit_device)
for j, pars in enumerate(par_seg):
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
mesh_values = jnp.meshgrid(*((xs,) + pars))
Expand All @@ -809,9 +810,9 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux

# auxiliary functions
f2 = lambda y, x, *pars: self.F_fy(x, y, *pars)
vmap_f2 = bm.jit(vmap(f2), device=self.jit_device)
vmap_brentq_f2 = bm.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
vmap_brentq_f1 = bm.jit(vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device)
vmap_f2 = jax.jit(vmap(f2), device=self.jit_device)
vmap_brentq_f2 = jax.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
vmap_brentq_f1 = jax.jit(vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device)

for j, Ps in enumerate(par_seg):
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
Expand Down Expand Up @@ -859,7 +860,7 @@ def _get_fp_candidates_by_aux_rank(self, num_segments=1, num_rank=100):
xs = self.resolutions[self.x_var]
ys = self.resolutions[self.y_var]
P = tuple(self.resolutions[p] for p in self.target_par_names)
f_select = bm.jit(vmap(lambda vals, ids: vals[ids], in_axes=(1, 1)))
f_select = jax.jit(vmap(lambda vals, ids: vals[ids], in_axes=(1, 1)))

# num seguments
if isinstance(num_segments, int):
Expand Down Expand Up @@ -939,10 +940,10 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,

if self.convert_type() == C.x_by_y:
num_seg = len(self.resolutions[self.y_var])
f_vmap = bm.jit(vmap(self.F_y_convert[1]))
f_vmap = jax.jit(vmap(self.F_y_convert[1]))
else:
num_seg = len(self.resolutions[self.x_var])
f_vmap = bm.jit(vmap(self.F_x_convert[1]))
f_vmap = jax.jit(vmap(self.F_x_convert[1]))
# get the signs
signs = jnp.sign(f_vmap(candidates, *args))
signs = signs.reshape((num_seg, -1))
Expand Down Expand Up @@ -972,10 +973,10 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
# get another value
if self.convert_type() == C.x_by_y:
y_values = fps
x_values = bm.jit(vmap(self.F_y_convert[0]))(y_values, *args)
x_values = jax.jit(vmap(self.F_y_convert[0]))(y_values, *args)
else:
x_values = fps
y_values = bm.jit(vmap(self.F_x_convert[0]))(x_values, *args)
y_values = jax.jit(vmap(self.F_x_convert[0]))(x_values, *args)
fps = jnp.stack([x_values, y_values]).T
return fps, selected_ids, args

Expand Down Expand Up @@ -1042,7 +1043,7 @@ def F_fz(self):
wrapper = utils.std_derivative(arguments, self.target_var_names, self.target_par_names)
f = wrapper(self.model.f_derivatives[self.z_var])
f = partial(f, **(self.pars_update + self.fixed_vars))
self.analyzed_results[C.F_fz] = bm.jit(f, device=self.jit_device)
self.analyzed_results[C.F_fz] = jax.jit(f, device=self.jit_device)
return self.analyzed_results[C.F_fz]

def fz_signs(self, pars=(), cache=False):
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/analysis/lowdim/lowdim_bifurcation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, model, target_pars, target_vars, fixed_vars=None,
@property
def F_vmap_dfxdx(self):
if C.F_vmap_dfxdx not in self.analyzed_results:
f = bm.jit(vmap(bm.vector_grad(self.F_fx, argnums=0)), device=self.jit_device)
f = jax.jit(vmap(bm.vector_grad(self.F_fx, argnums=0)), device=self.jit_device)
self.analyzed_results[C.F_vmap_dfxdx] = f
return self.analyzed_results[C.F_vmap_dfxdx]

Expand Down Expand Up @@ -163,7 +163,7 @@ def F_vmap_jacobian(self):
if C.F_vmap_jacobian not in self.analyzed_results:
f1 = lambda xy, *args: jnp.array([self.F_fx(xy[0], xy[1], *args),
self.F_fy(xy[0], xy[1], *args)])
f2 = bm.jit(vmap(bm.jacobian(f1)), device=self.jit_device)
f2 = jax.jit(vmap(bm.jacobian(f1)), device=self.jit_device)
self.analyzed_results[C.F_vmap_jacobian] = f2
return self.analyzed_results[C.F_vmap_jacobian]

Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/analysis/lowdim/lowdim_phase_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(self,
@property
def F_vmap_brentq_fy(self):
if C.F_vmap_brentq_fy not in self.analyzed_results:
f_opt = bm.jit(vmap(utils.jax_brentq(self.F_fy)))
f_opt = jax.jit(vmap(utils.jax_brentq(self.F_fy)))
self.analyzed_results[C.F_vmap_brentq_fy] = f_opt
return self.analyzed_results[C.F_vmap_brentq_fy]

Expand Down
8 changes: 5 additions & 3 deletions brainpy/_src/analysis/lowdim/tests/test_phase_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import jax.numpy as jnp


block = False
show = False


class TestPhasePlane(unittest.TestCase):
Expand All @@ -27,7 +27,8 @@ def int_x(x, t, Iext):
plt.ion()
analyzer.plot_vector_field()
analyzer.plot_fixed_point()
plt.show(block=block)
if show:
plt.show()
plt.close()
bp.math.disable_x64()

Expand Down Expand Up @@ -74,6 +75,7 @@ def int_s2(s2, t, s1):
analyzer.plot_vector_field()
analyzer.plot_nullcline(coords=dict(s2='s2-s1'))
analyzer.plot_fixed_point()
plt.show(block=block)
if show:
plt.show()
plt.close()
bp.math.disable_x64()
6 changes: 3 additions & 3 deletions brainpy/_src/analysis/utils/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def brentq_candidates(vmap_f, *values, args=()):

def brentq_roots(f, starts, ends, *vmap_args, args=()):
in_axes = (0, 0, tuple([0] * len(vmap_args)) + tuple([None] * len(args)))
vmap_f_opt = bm.jit(vmap(jax_brentq(f), in_axes=in_axes))
vmap_f_opt = jax.jit(vmap(jax_brentq(f), in_axes=in_axes))
all_args = vmap_args + args
if len(all_args):
res = vmap_f_opt(starts, ends, all_args)
Expand Down Expand Up @@ -397,7 +397,7 @@ def roots_of_1d_by_x(f, candidates, args=()):
return fps
starts = candidates[candidate_ids]
ends = candidates[candidate_ids + 1]
f_opt = bm.jit(vmap(jax_brentq(f), in_axes=(0, 0, None)))
f_opt = jax.jit(vmap(jax_brentq(f), in_axes=(0, 0, None)))
res = f_opt(starts, ends, args)
valid_idx = jnp.where(res['status'] == ECONVERGED)[0]
fps2 = res['root'][valid_idx]
Expand All @@ -406,7 +406,7 @@ def roots_of_1d_by_x(f, candidates, args=()):

def roots_of_1d_by_xy(f, starts, ends, args):
f = f_without_jaxarray_return(f)
f_opt = bm.jit(vmap(jax_brentq(f)))
f_opt = jax.jit(vmap(jax_brentq(f)))
res = f_opt(starts, ends, (args,))
valid_idx = jnp.where(res['status'] == ECONVERGED)[0]
xs = res['root'][valid_idx]
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/analysis/utils/others.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from typing import Union, Dict

import jax
import jax.numpy as jnp
import numpy as np
from jax import vmap
from jax.tree_util import tree_map

import brainpy.math as bm
Expand Down Expand Up @@ -80,7 +80,7 @@ def get_sign(f, xs, ys):

def get_sign2(f, *xyz, args=()):
in_axes = tuple(range(len(xyz))) + tuple([None] * len(args))
f = bm.jit(vmap(f_without_jaxarray_return(f), in_axes=in_axes))
f = jax.jit(jax.vmap(f_without_jaxarray_return(f), in_axes=in_axes))
xyz = tuple((v.value if isinstance(v, bm.Array) else v) for v in xyz)
XYZ = jnp.meshgrid(*xyz)
XYZ = tuple(jnp.moveaxis(v, 1, 0).flatten() for v in XYZ)
Expand Down
14 changes: 4 additions & 10 deletions brainpy/_src/dyn/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,13 +668,9 @@ def _get_f_predict(self, shared_args: Dict = None):

shared_kwargs_str = serialize_kwargs(shared_args)
if shared_kwargs_str not in self._f_predict_compiled:
dyn_vars = self.target.vars()
dyn_vars.update(self._dyn_vars)
dyn_vars.update(self.vars(level=0))
dyn_vars = dyn_vars.unique()

if self._memory_efficient:
_jit_step = bm.jit(partial(self._step_func_predict, shared_args), dyn_vars=dyn_vars)
_jit_step = bm.jit(partial(self._step_func_predict, shared_args))

def run_func(all_inputs):
outs = None
Expand All @@ -688,12 +684,10 @@ def run_func(all_inputs):
return outs, None

else:
@bm.jit(dyn_vars=dyn_vars)
step = partial(self._step_func_predict, shared_args)

def run_func(all_inputs):
return bm.for_loop(partial(self._step_func_predict, shared_args),
all_inputs,
dyn_vars=dyn_vars,
jit=self.jit['predict'])
return bm.for_loop(step, all_inputs, jit=self.jit['predict'])

self._f_predict_compiled[shared_kwargs_str] = run_func

Expand Down
3 changes: 1 addition & 2 deletions brainpy/_src/dyn/synapses/abstract_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,8 +977,7 @@ def update(self, tdi):
inp = bm.cond((a > 5) * (b > 5),
lambda _: self.rng.normal(a, b * p, self.target_var.shape),
lambda _: self.rng.binomial(self.num_input, p, self.target_var.shape),
None,
dyn_vars=self.rng)
None)
self.target_var += inp * self.weight

def __repr__(self):
Expand Down
3 changes: 1 addition & 2 deletions brainpy/_src/dyn/synapses_v2/others.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ def update(self):
inp = bm.cond((a > 5) * (b > 5),
lambda _: self.rng.normal(a, b * p, self.target_shape),
lambda _: self.rng.binomial(self.num_input, p, self.target_shape),
None,
dyn_vars=self.rng)
None)
return inp * self.weight

def __repr__(self):
Expand Down
3 changes: 1 addition & 2 deletions brainpy/_src/dyn/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,15 +278,14 @@ def __call__(

else:
shared = tools.DotDict()
if self.t0 is not None:
if self.t0 is not None:
shared['t'] = jnp.arange(0, self.dt * length[0], self.dt) + self.t0.value
if self.i0 is not None:
shared['i'] = jnp.arange(0, length[0]) + self.i0.value

assert not self.no_state
results = bm.for_loop(functools.partial(self._run, self.shared_arg),
(shared, xs),
child_objs=(self.target, share),
jit=self.jit,
remat=self.remat)
if self.i0 is not None:
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/encoding/stateful_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def f(i):
inputs.value -= w * spike
return spike

return bm.for_loop(f, bm.arange(num_step).value, dyn_vars=inputs)
return bm.for_loop(f, bm.arange(num_step).value)


class LatencyEncoder(Encoder):
Expand Down
Loading

0 comments on commit 3d63531

Please sign in to comment.