From 8f0c8fbd97938b53735a5e6d391b29355f189de9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Sat, 19 Oct 2024 11:11:11 +0100 Subject: [PATCH] update jax notebook (#2552) * update jax notebook * Apply suggestions from code review Co-authored-by: Daniel Weindl --------- Co-authored-by: Daniel Weindl --- python/examples/example_jax/ExampleJax.ipynb | 174 ++++++++++--------- 1 file changed, 89 insertions(+), 85 deletions(-) diff --git a/python/examples/example_jax/ExampleJax.ipynb b/python/examples/example_jax/ExampleJax.ipynb index 931dfb7e28..1d7d0967e1 100644 --- a/python/examples/example_jax/ExampleJax.ipynb +++ b/python/examples/example_jax/ExampleJax.ipynb @@ -16,7 +16,12 @@ "cell_type": "code", "execution_count": 1, "id": "b0a66e18", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-18T20:35:50.848939Z", + "start_time": "2024-10-18T20:35:43.411198Z" + } + }, "outputs": [], "source": [ "import jax\n", @@ -45,10 +50,15 @@ "cell_type": "code", "execution_count": 2, "id": "9166e3bf", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-18T20:36:51.490345Z", + "start_time": "2024-10-18T20:35:50.853467Z" + } + }, "outputs": [], "source": [ - "import petab\n", + "import petab.v1 as petab\n", "\n", "model_name = \"Boehm_JProteomeRes2014\"\n", "yaml_file = f\"https://raw.githubusercontent.com/Benchmarking-Initiative/Benchmark-Models-PEtab/master/Benchmark-Models/{model_name}/{model_name}.yaml\"\n", @@ -67,7 +77,12 @@ "cell_type": "code", "execution_count": 3, "id": "b04ca561", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-18T20:36:51.740603Z", + "start_time": "2024-10-18T20:36:51.725877Z" + } + }, "outputs": [ { "data": { @@ -166,7 +181,7 @@ " ratio\n", " ratio\n", " lin\n", - " -5.00000\n", + " 0.00000\n", " 5\n", " 0.693000\n", " 0\n", @@ -212,15 +227,15 @@ "" ], "text/plain": [ - " parameterName parameterScale lowerBound \n", + " parameterName parameterScale lowerBound \\\n", "parameterId \n", - "Epo_degradation_BaF3 EPO_{degradation,BaF3} log10 0.00001 \\\n", + "Epo_degradation_BaF3 EPO_{degradation,BaF3} log10 0.00001 \n", "k_exp_hetero k_{exp,hetero} log10 0.00001 \n", "k_exp_homo k_{exp,homo} log10 0.00001 \n", "k_imp_hetero k_{imp,hetero} log10 0.00001 \n", "k_imp_homo k_{imp,homo} log10 0.00001 \n", "k_phos k_{phos} log10 0.00001 \n", - "ratio ratio lin -5.00000 \n", + "ratio ratio lin 0.00000 \n", "sd_pSTAT5A_rel \\sigma_{pSTAT5A,rel} log10 0.00001 \n", "sd_pSTAT5B_rel \\sigma_{pSTAT5B,rel} log10 0.00001 \n", "sd_rSTAT5A_rel \\sigma_{rSTAT5A,rel} log10 0.00001 \n", @@ -262,14 +277,19 @@ "cell_type": "code", "execution_count": 4, "id": "6ada3fb8", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "start_time": "2024-10-18T20:36:51.765461Z" + }, + "jupyter": { + "is_executing": true + } + }, "outputs": [], "source": [ "from amici.petab.petab_import import import_petab_problem\n", "\n", - "amici_model = import_petab_problem(\n", - " petab_problem, compile_=True, verbose=False\n", - ")" + "amici_model = import_petab_problem(petab_problem, compile_=True, verbose=False)" ] }, { @@ -279,7 +299,7 @@ "source": [ "## JAX implementation\n", "\n", - "For full jax support, we would have to implement a new [primitive](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html), which would require quite a bit of engineering, and in the end wouldn't add much benefit since AMICI can't run on GPUs. Instead, will interface AMICI using the experimental jax module [host_callback](https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html)." + "For full jax support, we would have to implement a new [primitive](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html), which would require quite a bit of engineering, and in the end wouldn't add much benefit since AMICI can't run on GPUs. Instead, we will interface AMICI using the jax method [pure_callback](https://jax.readthedocs.io/en/latest/_autosummary/jax.pure_callback.html)." ] }, { @@ -287,7 +307,7 @@ "id": "6bbf2f06", "metadata": {}, "source": [ - "To do so, we define a base function that only takes a single argument (the parameters) and runs simulation using petab via [simulate_petab](https://amici.readthedocs.io/en/latest/generated/amici.petab_objective.html#amici.petab_objective.simulate_petab). To enable gradient computation later on, we create a solver object and set the sensitivity order to first order and pass it to `simulate_petab`. Moreover, `simulate_petab` expects a dictionary of parameters, so we create a dictionary using the free parameter ids from the petab problem. As we want to implement parameter transformation in JAX, we disable parameter scaling in petab by passing `scaled_parameters=True`." + "To do so, we define a base function that only takes a single argument (the parameters) and runs simulation using petab via [simulate_petab](https://amici.readthedocs.io/en/latest/generated/amici.petab_objective.html#amici.petab_objective.simulate_petab). To enable gradient computation later on, we create a solver object and set the sensitivity order to first order and pass it to `simulate_petab`. Moreover, `simulate_petab` expects a dictionary of parameters, so we create a dictionary using the free parameter ids from the petab problem." ] }, { @@ -304,13 +324,18 @@ "amici_solver.setSensitivityOrder(amici.SensitivityOrder.first)\n", "\n", "\n", - "def amici_hcb_base(parameters: jnp.array):\n", - " return simulate_petab(\n", + "def amici_callback_base(parameters: jnp.array):\n", + " ret = simulate_petab(\n", " petab_problem,\n", " amici_model,\n", " problem_parameters=dict(zip(petab_problem.x_free_ids, parameters)),\n", " solver=amici_solver,\n", - " )" + " )\n", + " llh = np.array(ret[\"llh\"])\n", + " sllh = np.array(\n", + " tuple(ret[\"sllh\"][par_id] for par_id in petab_problem.x_free_ids)\n", + " )\n", + " return llh, sllh" ] }, { @@ -318,7 +343,7 @@ "id": "6f6201e8", "metadata": {}, "source": [ - "Now we can use this base function to create two separate functions that compute the log-likelihood (`llh`) and its gradient (`sllh`) in two individual routines. Note that, as we are using the same base function here, the log-likelihood computation will also run with sensitivities which is not necessary and will add some overhead. This is only out of convenience and should be fixed in an application where efficiency is important." + "Now we can use this base function to create two separate functions that return the log-likelihood (`llh`) and a tuple with log-likelihood and its gradient (`sllh`). Both functions use [pure_callback](https://jax.readthedocs.io/en/latest/_autosummary/jax.pure_callback.html) such that they can be called by other jax functions. Note that, as we are using the same base function here, the log-likelihood computation will also run with sensitivities which is not necessary and will add some overhead. This is only out of convenience and should be fixed in an application where efficiency is important." ] }, { @@ -328,14 +353,25 @@ "metadata": {}, "outputs": [], "source": [ - "def amici_hcb_llh(parameters: jnp.array):\n", - " return amici_hcb_base(parameters)[\"llh\"]\n", + "def device_fun_llh(x: jnp.array):\n", + " return jax.pure_callback(\n", + " lambda x: amici_callback_base(x)[0],\n", + " jax.ShapeDtypeStruct((), x.dtype),\n", + " x,\n", + " )\n", "\n", "\n", - "def amici_hcb_sllh(parameters: jnp.array):\n", - " sllh = amici_hcb_base(parameters)[\"sllh\"]\n", - " return jnp.asarray(\n", - " tuple(sllh[par_id] for par_id in petab_problem.x_free_ids)\n", + "def device_fun_llh_sllh(x: jnp.array):\n", + " return jax.pure_callback(\n", + " amici_callback_base,\n", + " (\n", + " jax.ShapeDtypeStruct((), x.dtype),\n", + " jax.ShapeDtypeStruct(\n", + " x.shape,\n", + " x.dtype,\n", + " ),\n", + " ),\n", + " x,\n", " )" ] }, @@ -344,7 +380,7 @@ "id": "98e819bd", "metadata": {}, "source": [ - "Now we can finally define the JAX function that runs amici simulation using the host callback. We add a `custom_jvp` decorator so that we can define a custom jacobian vector product function in the next step. More details about custom jacobian vector product functions can be found in the [JAX documentation](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html)" + "Even though the two functions that we just defined are valid jax functions, they can't compute derivatives yet. To support derivative computation, we have to define a new function with a `jax.custom_jvp` decorator, which specifies that we will define a custom jacobian vector product (jvp) function, as well as the corresponding jvp using the `@jax_objective.defjvp` decorator. More details about custom jacobian vector product functions can be found in the [JAX documentation](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html)" ] }, { @@ -354,49 +390,17 @@ "metadata": {}, "outputs": [], "source": [ - "import jax.experimental.host_callback as hcb\n", - "from jax import custom_jvp\n", - "\n", - "import numpy as np\n", + "@jax.custom_jvp\n", + "def jax_objective(parameters: jnp.array):\n", + " return device_fun_llh(parameters)\n", "\n", "\n", - "@custom_jvp\n", - "def jax_objective(parameters: jnp.array):\n", - " return hcb.call(\n", - " amici_hcb_llh,\n", - " parameters,\n", - " result_shape=jax.ShapeDtypeStruct((), np.float64),\n", - " )" - ] - }, - { - "cell_type": "markdown", - "id": "c75535a5", - "metadata": {}, - "source": [ - "Now we define the function that implement the jacobian vector product. This effectively just returns the objective function value (computed using the previously defined `jax_objective`) as well as the inner product of the gradient (computed using a host callback to the previously defined `amici_hcb_sllh`) and the tangents vector. Note that this implementation performs two simulation runs, one for the function value and one for the gradient, which is inefficient and could be avoided by caching solutions." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "5a68c812", - "metadata": {}, - "outputs": [], - "source": [ "@jax_objective.defjvp\n", "def jax_objective_jvp(primals: jnp.array, tangents: jnp.array):\n", " (parameters,) = primals\n", " (x_dot,) = tangents\n", - " llh = jax_objective(parameters)\n", - " sllh = hcb.call(\n", - " amici_hcb_sllh,\n", - " parameters,\n", - " result_shape=jax.ShapeDtypeStruct(\n", - " (petab_problem.parameter_df.estimate.sum(),), np.float64\n", - " ),\n", - " )\n", - " return llh, sllh.dot(x_dot)" + " llh, sllh = device_fun_llh_sllh(parameters)\n", + " return llh, sllh @ x_dot" ] }, { @@ -404,25 +408,23 @@ "id": "379485ca", "metadata": {}, "source": [ - "As last step, we implement the parameter transformation in jax. This effectively just extracts parameter scales from the petab problem, implements rescaling in jax and then passes the scaled parameters to the previously objective function we previously defined. We add the `value_and_grad` decorator such that the generated jax function returns both function value and function gradient in a tuple. Moreover, we add the `jax.jit` decorator such that the function is [just in time compiled](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) upon the first function call." + "As last step, we implement the parameter transformation in jax. This effectively just extracts parameter scales from the petab problem, implements rescaling in jax and then passes the scaled parameters to the objective function we previously defined. We add the `jax.value_and_grad` decorator such that the generated jax function returns both function value and function gradient in a tuple. Moreover, we add the `jax.jit` decorator such that the function is [just-in-time compiled](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) upon the first function call." ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "3ab8fde9", "metadata": {}, "outputs": [], "source": [ - "from jax import value_and_grad\n", - "\n", "parameter_scales = petab_problem.parameter_df.loc[\n", " petab_problem.x_free_ids, petab.PARAMETER_SCALE\n", "].values\n", "\n", "\n", "@jax.jit\n", - "@value_and_grad\n", + "@jax.value_and_grad\n", "def jax_objective_with_parameter_transform(parameters: jnp.array):\n", " par_scaled = jnp.asarray(\n", " tuple(\n", @@ -449,11 +451,13 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "id": "b7e9ff3b", "metadata": {}, "outputs": [], "source": [ + "import numpy as np\n", + "\n", "parameters = dict(zip(petab_problem.x_free_ids, petab_problem.x_nominal_free))\n", "scaled_parameters = petab_problem.scale_parameters(parameters)\n", "scaled_parameters_np = np.asarray(list(scaled_parameters.values()))" @@ -461,13 +465,13 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "id": "fb3085a8", "metadata": {}, "outputs": [], "source": [ "llh_jax, sllh_jax = jax_objective_with_parameter_transform(\n", - " scaled_parameters_np\n", + " jnp.array(scaled_parameters_np)\n", ")" ] }, @@ -481,7 +485,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "id": "48451b0e", "metadata": {}, "outputs": [], @@ -498,7 +502,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "id": "2628db12", "metadata": {}, "outputs": [ @@ -544,7 +548,7 @@ "llh -138.221997 -138.222 -2.135248e-08" ] }, - "execution_count": 13, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -564,7 +568,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "id": "0846523f", "metadata": {}, "outputs": [ @@ -617,7 +621,7 @@ " k_imp_hetero\n", " -0.005414\n", " -0.005403\n", - " 1.973517e-03\n", + " 1.973604e-03\n", " \n", " \n", " k_imp_homo\n", @@ -635,7 +639,7 @@ " sd_pSTAT5A_rel\n", " -0.010784\n", " -0.010800\n", - " -1.469604e-03\n", + " -1.469518e-03\n", " \n", " \n", " sd_pSTAT5B_rel\n", @@ -658,15 +662,15 @@ "Epo_degradation_BaF3 -0.022045 -0.022034 4.645833e-04\n", "k_exp_hetero -0.055323 -0.055323 8.646725e-08\n", "k_exp_homo -0.005789 -0.005801 -2.013520e-03\n", - "k_imp_hetero -0.005414 -0.005403 1.973517e-03\n", + "k_imp_hetero -0.005414 -0.005403 1.973604e-03\n", "k_imp_homo 0.000045 0.000045 1.119566e-06\n", "k_phos -0.007907 -0.007794 1.447768e-02\n", - "sd_pSTAT5A_rel -0.010784 -0.010800 -1.469604e-03\n", + "sd_pSTAT5A_rel -0.010784 -0.010800 -1.469518e-03\n", "sd_pSTAT5B_rel -0.024037 -0.024037 -8.729860e-06\n", "sd_rSTAT5A_rel -0.019191 -0.019186 2.829431e-04" ] }, - "execution_count": 14, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -692,7 +696,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "id": "5f81c693", "metadata": {}, "outputs": [], @@ -713,7 +717,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "id": "25e8b301", "metadata": {}, "outputs": [ @@ -759,7 +763,7 @@ "llh -138.221997 -138.221997 -0.0" ] }, - "execution_count": 16, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -777,7 +781,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "id": "f31a3927", "metadata": {}, "outputs": [ @@ -879,7 +883,7 @@ "sd_rSTAT5A_rel -0.019191 -0.019191 -0.0" ] }, - "execution_count": 17, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -911,7 +915,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.2" + "version": "3.13.0" } }, "nbformat": 4,