Skip to content

Commit

Permalink
update template
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Nov 16, 2024
1 parent 404d82e commit e399f4c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 56 deletions.
21 changes: 0 additions & 21 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,6 @@ repos:
args: [--allow-multiple-documents]
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.7
hooks:
# Run the linter.
- id: ruff
args:
- --fix
- --config
- python/sdist/pyproject.toml

# Run the formatter.
- id: ruff-format
args:
- --config
- python/sdist/pyproject.toml

- repo: https://github.com/asottile/pyupgrade
rev: v3.17.0
hooks:
- id: pyupgrade
args: ["--py310-plus"]

exclude: '^(ThirdParty|models)/'
64 changes: 29 additions & 35 deletions python/sdist/amici/jax.template.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jax.numpy as jnp
from interpax import interp1d

from amici.jax import JAXModel
from amici.jax.model import JAXModel


class JAXModel_TPL_MODEL_NAME(JAXModel):
Expand All @@ -11,35 +11,32 @@ def __init__(self):
@staticmethod
def xdot(t, x, args):

p, k, tcl = args
pk, tcl = args

TPL_X_SYMS = x
TPL_P_SYMS = p
TPL_K_SYMS = k
TPL_PK_SYMS = pk
TPL_TCL_SYMS = tcl
TPL_W_SYMS = JAXModel_TPL_MODEL_NAME._w(t, x, p, k, tcl)
TPL_W_SYMS = JAXModel_TPL_MODEL_NAME._w(t, x, pk, tcl)

TPL_XDOT_EQ

return TPL_XDOT_RET

@staticmethod
def _w(t, x, p, k, tcl):
def _w(t, x, pk, tcl):

TPL_X_SYMS = x
TPL_P_SYMS = p
TPL_K_SYMS = k
TPL_PK_SYMS = pk
TPL_TCL_SYMS = tcl

TPL_W_EQ

return TPL_W_RET

@staticmethod
def x0(p, k):
def x0(pk):

TPL_P_SYMS = p
TPL_K_SYMS = k
TPL_PK_SYMS = pk

TPL_X0_EQ

Expand All @@ -65,55 +62,48 @@ def x_rdata(x, tcl):
return TPL_X_RDATA_RET

@staticmethod
def tcl(x, p, k):
def tcl(x, pk):

TPL_X_RDATA_SYMS = x
TPL_P_SYMS = p
TPL_K_SYMS = k
TPL_PK_SYMS = pk

TPL_TOTAL_CL_EQ

return TPL_TOTAL_CL_RET

@staticmethod
def y(t, x, p, k, tcl):
def y(self, t, x, pk, tcl):

TPL_X_SYMS = x
TPL_P_SYMS = p
TPL_K_SYMS = k
TPL_W_SYMS = JAXModel_TPL_MODEL_NAME._w(t, x, p, k, tcl)
TPL_PK_SYMS = pk
TPL_W_SYMS = self._w(t, x, pk, tcl)

TPL_Y_EQ

return TPL_Y_RET

@staticmethod
def sigmay(y, p, k):
def sigmay(self, y, pk):
TPL_PK_SYMS = pk

TPL_Y_SYMS = y
TPL_P_SYMS = p
TPL_K_SYMS = k

TPL_SIGMAY_EQ

return TPL_SIGMAY_RET

@staticmethod
def Jy(y, my, sigmay):

def llh(self, t, x, pk, tcl, my, iy):
y = self.y(t, x, pk, tcl)
TPL_Y_SYMS = y
TPL_MY_SYMS = my
sigmay = self.sigmay(y, pk)
TPL_SIGMAY_SYMS = sigmay

TPL_JY_EQ

return TPL_JY_RET

@property
def parameter_ids(self):
return TPL_P_IDS

@property
def fixed_parameter_ids(self):
return TPL_K_IDS
return jnp.array([
TPL_JY_RET.at[iy].get(),
y.at[iy].get(),
sigmay.at[iy].get()
])

@property
def observable_ids(self):
Expand All @@ -122,3 +112,7 @@ def observable_ids(self):
@property
def state_ids(self):
return TPL_X_IDS

@property
def parameter_ids(self):
return TPL_PK_IDS

0 comments on commit e399f4c

Please sign in to comment.