From ea218a95a87a62ebefeed1db88e1202ae358ea0f Mon Sep 17 00:00:00 2001 From: olive004 Date: Thu, 31 Oct 2024 23:00:05 +0000 Subject: [PATCH] bug: jax item assignment --- notebooks/24_autodiff.ipynb | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/notebooks/24_autodiff.ipynb b/notebooks/24_autodiff.ipynb index b36fc1a3..39477cdc 100644 --- a/notebooks/24_autodiff.ipynb +++ b/notebooks/24_autodiff.ipynb @@ -119,7 +119,8 @@ "idxs_output = np.array([species.index(s) for s in species_output])\n", "idxs_unbound = np.array([species.index(s) for s in species_unbound])\n", "idxs_bound = np.array([species.index(s) for s in species_bound])\n", - "signal_onehot = np.array([1 if s in idxs_signal else 0 for s in np.arange(len(species))])\n", + "signal_onehot = jnp.array([1 if s in idxs_signal else 0 for s in np.arange(len(species))])\n", + "signal_onehot_inv = (signal_onehot == 0) * 1.0\n", "\n", "# Circuit parameters\n", "n_circuits = 20\n", @@ -229,8 +230,8 @@ "def simulate(y0, params):\n", " \"\"\" (y11, ddys1, y_sens, y_prec) = ys1 \"\"\"\n", " ts0, ys0 = jax.vmap(sim_func)(y0=y0, reverse_rates=params)\n", - " y01 = np.array(ys0[:, -1])\n", - " y01[:, idxs_signal] = signal_target * y01[:, idxs_signal]\n", + " y01 = (jnp.array(ys0[:, -1]) * signal_onehot * signal_target) + (jnp.array(ys0[:, -1]) * signal_onehot_inv)\n", + " # y01[:, idxs_signal] = signal_target * y01[:, idxs_signal]\n", " ts1, ys1 = jax.vmap(simulate_signal)(y0=y01, reverse_rates=params)\n", " # (y11, ddys1, y_sens, y_prec) = ys1\n", " return (ts0, ys0), (ts1, ys1)" @@ -305,7 +306,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [