From 6718a1e3648d46693b11094ff15575abc5ab11a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Mon, 2 Dec 2024 21:14:07 +0000 Subject: [PATCH] remove PK --- python/sdist/amici/jax/ode_export.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index cb7a299a8b..6ef7c2b9c1 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -17,7 +17,6 @@ from typing import ( TYPE_CHECKING, ) -from itertools import chain import sympy as sp @@ -202,7 +201,7 @@ def _generate_jax_code(self) -> None: "x_rdata", "total_cl", ) - sym_names = ("x", "tcl", "w", "my", "y", "sigmay", "x_rdata") + sym_names = ("p", "x", "tcl", "w", "my", "y", "sigmay", "x_rdata") indent = 8 @@ -236,16 +235,6 @@ def _generate_jax_code(self) -> None: # tuple of variable names (ids as they are unique) **_jax_variable_ids(self.model, ("p", "k", "y", "x")), **{ - # in jax model we do not need to distinguish between p (parameters) and - # k (fixed parameters) so we use a single variable combining both - "PK_SYMS": "".join( - str(strip_pysb(s)) + ", " - for s in chain(self.model.sym("p"), self.model.sym("k")) - ), - "PK_IDS": "".join( - f'"{strip_pysb(s)}", ' - for s in chain(self.model.sym("p"), self.model.sym("k")) - ), "MODEL_NAME": self.model_name, # keep track of the API version that the model was generated with so we # can flag conflicts in the future