
neural ode example
Marmaduke Woodman committed Nov 22, 2024
1 parent 2dace19 commit d97735e
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions examples/neural-ode.py
@@ -10,19 +10,19 @@ def dfun1(x, p):
     c, p = p
     return vb.mpr_dfun(x, (c,0), p)
 
-dt = 0.1
+dt = 0.05
+nt = 400
 _, loop = vb.make_ode(dt, dfun1)
 
 def run_it(pars):
     c, tau, r0 = pars
     rv0 = jp.r_[r0, -2.0]
     p = c, vb.mpr_default_theta._replace(tau=tau)
-    nt = 200
     rvs = loop(rv0, jp.r_[:nt], p)
     return rvs
 
 run_them = jax.jit(jax.vmap(run_it))
-ng = 4j
+ng = 16j
 cs,taus,r0s=jp.mgrid[0.0:2.0:ng, 1.0:3.0:ng, 0.001:1.0:ng]
 pars = jp.c_[cs.ravel(), taus.ravel(), r0s.ravel()]
 rvs = run_them(pars)
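
(Aside: a minimal sketch of what a fixed-step ODE loop like the one returned by vb.make_ode could look like, assuming an explicit Euler scheme over jax.lax.scan; the actual vbjax scheme and signature may differ. It also shows why ng = 16j yields 16 grid points per axis: a complex step in jp.mgrid is a point count, endpoints included, not a spacing.)

import jax
import jax.numpy as jp

def make_euler_ode(dt, dfun):
    # hypothetical stand-in for vb.make_ode: x <- x + dt * dfun(x, p)
    def step(x, p):
        return x + dt * dfun(x, p)
    def loop(x0, ts, p):
        def body(x, t):
            x = step(x, p)
            return x, x            # carry the new state and also record it
        _, xs = jax.lax.scan(body, x0, ts)
        return xs
    return step, loop

_, demo_loop = make_euler_ode(0.05, lambda x, p: -p * x)
xs = demo_loop(jp.r_[1.0], jp.r_[:400], 0.1)   # (400, 1) decaying trace

# complex step in mgrid = number of points, endpoints included
assert jp.mgrid[0.0:2.0:16j].shape == (16,)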
@@ -33,9 +33,9 @@ def run_it(pars):
 pl.semilogy(rvs[:,:,0].T, 'k.-', alpha=0.1)
 pl.grid(1)
 pl.savefig('scratch.jpg')
 
-# 1/0
+# setup a neural ode for this use case
-wb, mlp = vb.make_dense_layers(2+2, latent_dims=[16,16], out_dim=2)
+wb, mlp = vb.make_dense_layers(2+2, latent_dims=[128,128], out_dim=2, init_scl=1e-1)
 def dfun2(x, wb_pars):
     wb, pars = wb_pars
     x_ = jp.vstack((x, pars[:, :2].T)) # only c & tau
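
(Aside: dfun2 stacks the 2 state variables on the 2 conditioning parameters c and tau, so the network input has 2+2 = 4 rows. A sketch of what a (wb, mlp) pair like the one vb.make_dense_layers returns might look like; the tanh nonlinearity, (W, b) parameter layout, and column-wise batch convention are assumptions, not the actual vbjax implementation:)

import jax
import jax.numpy as jp

def make_dense_mlp(key, in_dim, latent_dims, out_dim, init_scl=1e-1):
    # hypothetical analogue of vb.make_dense_layers: params as (W, b) pairs
    sizes = [in_dim] + list(latent_dims) + [out_dim]
    keys = jax.random.split(key, len(sizes) - 1)
    wb = [(init_scl * jax.random.normal(k, (m, n)), jp.zeros((m, 1)))
          for k, n, m in zip(keys, sizes[:-1], sizes[1:])]
    def mlp(wb, x):                      # x: (in_dim, batch)
        for W, b in wb[:-1]:
            x = jp.tanh(W @ x + b)
        W, b = wb[-1]
        return W @ x + b                 # (out_dim, batch)
    return wb, mlp

wb_demo, mlp_demo = make_dense_mlp(jax.random.PRNGKey(0), 4, [128, 128], 2)
y = mlp_demo(wb_demo, jp.ones((4, 8)))   # (2, 8)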
@@ -46,27 +46,28 @@ def dfun2(x, wb_pars):
 
 _, mlploop = vb.make_ode(dt, dfun2)
 rvs_ = rvs.transpose(1,2,0)
 
 def loss(wb):
+    # init with r0s
-    x0 = rvs[:,0].T.at[0].set(pars[:,2])
-    xs = mlploop(x0, jp.r_[:200], (wb, pars))
+    r0 = r0s.ravel()
+    x0 = jp.array([r0, -2*jp.ones_like(r0)])
+    xs = mlploop(x0, jp.r_[:nt], (wb, pars))
     e = xs - rvs_
     l_x = jp.mean(jp.square(e))
-    l_dx = jp.mean(jp.square(jp.diff(e, axis=0)))
-    return l_x# + l_dx
+    l_dx = jp.mean(jp.square(jp.diff(xs, axis=0) - jp.diff(rvs_, axis=0)))
+    return l_x
 
 vg = jax.jit(jax.value_and_grad(loss))
 
 print(vg(wb)[0])
 
 # now let's descend the gradient
 from jax.example_libraries.optimizers import adam
-oinit, oupdate, oget = adam(1e-2)
+oinit, oupdate, oget = adam(1e-4)
 owb = oinit(wb)
-for i in (pbar := tqdm.trange(300)):
+for i in (pbar := tqdm.trange(1000)):
     v, g = vg(oget(owb))
     owb = oupdate(i, g, owb)
-    pbar.set_description(f'loss {v:0.2f}')
+    pbar.set_description(f'loss {v:0.5f}')
 
 # since it's trained, we can check how well it works,
 test_pars = jp.array([
Expand All @@ -76,10 +77,23 @@ def loss(wb):
rv_test = run_them(test_pars) # (len(test_pars), 200, 2)
x0_test = jp.array([test_pars[:,2],
jp.ones(len(test_pars))*-2.0])
x_test = mlploop(x0_test, jp.r_[:200], (oget(owb), test_pars))
print(x0_test)
x_test = mlploop(x0_test, jp.r_[:nt], (oget(owb), test_pars))

# show that
pl.figure()
pl.subplot(121)
pl.plot(rv_test[:, :, 0].T, 'k')
pl.plot(x_test[:, 0, :], 'r')
pl.subplot(122)
pl.plot(rv_test[:, :, 1].T, 'k')
pl.plot(x_test[:, 1, :], 'r')
pl.savefig('scratch.jpg')


# two more steps: run a full TVB style simulation


# lastly, run a parameter sweep

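(For reference, the adam factory from jax.example_libraries.optimizers returns an (init, update, get_params) triple: oinit(wb) wraps the parameters in optimizer state, oupdate(i, grads, state) applies the i-th Adam step, and oget(state) extracts the current parameters, which is why the training loop evaluates vg(oget(owb)) and feeds the gradient back through oupdate. The same pattern on a toy quadratic objective:)

import jax
import jax.numpy as jp
from jax.example_libraries.optimizers import adam

f = jax.jit(jax.value_and_grad(lambda w: jp.sum(jp.square(w - 3.0))))
oinit, oupdate, oget = adam(1e-1)
state = oinit(jp.zeros(2))
for i in range(200):
    v, g = f(oget(state))         # oget extracts params from optimizer state
    state = oupdate(i, g, state)  # apply the i-th Adam step
print(oget(state))                # approximately [3. 3.]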

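(A hypothetical sketch of that last TODO, written as a continuation of the script above, so it reuses mlploop, owb, oget, and nt from the diff rather than being standalone: sweep the trained surrogate over a fresh parameter grid.)

cs2, taus2, r02 = jp.mgrid[0.0:2.0:8j, 1.0:3.0:8j, 0.001:1.0:8j]
sweep_pars = jp.c_[cs2.ravel(), taus2.ravel(), r02.ravel()]
x0_sweep = jp.array([sweep_pars[:, 2],                  # init r with r0
                     -2.0 * jp.ones(len(sweep_pars))])  # init V at -2
x_sweep = mlploop(x0_sweep, jp.r_[:nt], (oget(owb), sweep_pars))  # (nt, 2, 512)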