Skip to content

Commit

Permalink
bug: jax item assignment
Browse files Browse the repository at this point in the history
  • Loading branch information
olive004 committed Oct 31, 2024
1 parent dd9c6d3 commit ea218a9
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions notebooks/24_autodiff.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@
"idxs_output = np.array([species.index(s) for s in species_output])\n",
"idxs_unbound = np.array([species.index(s) for s in species_unbound])\n",
"idxs_bound = np.array([species.index(s) for s in species_bound])\n",
"signal_onehot = np.array([1 if s in idxs_signal else 0 for s in np.arange(len(species))])\n",
"signal_onehot = jnp.array([1 if s in idxs_signal else 0 for s in np.arange(len(species))])\n",
"signal_onehot_inv = (signal_onehot == 0) * 1.0\n",
"\n",
"# Circuit parameters\n",
"n_circuits = 20\n",
Expand Down Expand Up @@ -229,8 +230,8 @@
"def simulate(y0, params):\n",
" \"\"\" (y11, ddys1, y_sens, y_prec) = ys1 \"\"\"\n",
" ts0, ys0 = jax.vmap(sim_func)(y0=y0, reverse_rates=params)\n",
" y01 = np.array(ys0[:, -1])\n",
" y01[:, idxs_signal] = signal_target * y01[:, idxs_signal]\n",
" y01 = (jnp.array(ys0[:, -1]) * signal_onehot * signal_target) + (jnp.array(ys0[:, -1]) * signal_onehot_inv)\n",
" # y01[:, idxs_signal] = signal_target * y01[:, idxs_signal]\n",
" ts1, ys1 = jax.vmap(simulate_signal)(y0=y01, reverse_rates=params)\n",
" # (y11, ddys1, y_sens, y_prec) = ys1\n",
" return (ts0, ys0), (ts1, ys1)"
Expand Down Expand Up @@ -305,7 +306,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand Down

0 comments on commit ea218a9

Please sign in to comment.