Skip to content

Commit

Permalink
Merge branch 'develop' into jax_sciml
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Dec 2, 2024
2 parents 0605b78 + f3a97c2 commit d166f03
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 54 deletions.
15 changes: 0 additions & 15 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,7 @@ repos:
args: [--allow-multiple-documents]
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.7
hooks:
# Run the linter.
- id: ruff
args:
- --fix
- --config
- python/sdist/pyproject.toml

# Run the formatter.
- id: ruff-format
args:
- --config
- python/sdist/pyproject.toml

- repo: https://github.com/asottile/pyupgrade
rev: v3.17.0
Expand Down
1 change: 1 addition & 0 deletions include/amici/model_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ struct ModelStateDerived {
dwdx.set_ctx(sunctx_);
}
sspl_.set_ctx(sunctx_);
x_pos_tmp_.set_ctx(sunctx_);
dwdw_.set_ctx(sunctx_);
dJydy_dense_.set_ctx(sunctx_);
}
Expand Down
2 changes: 1 addition & 1 deletion python/sdist/amici/jax/jax.template.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

TPL_NET_IMPORTS


class JAXModel_TPL_MODEL_NAME(JAXModel):
api_version = TPL_MODEL_API_VERSION

def __init__(self):
self.jax_py_file = Path(__file__).resolve()
self.nns = {TPL_NETS}

super().__init__()

def _xdot(self, t, x, args):
Expand Down
39 changes: 39 additions & 0 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,45 @@ def load(cls, directory: Path):
with open(directory / "parameters.pkl", "rb") as f:
return eqx.tree_deserialise_leaves(f, problem)

def save(self, directory: Path):
"""
Save the problem to a directory.
:param directory:
Directory to save the problem to.
"""
self._petab_problem.to_files(
prefix_path=directory,
model_file="model",
condition_file="conditions.tsv",
measurement_file="measurements.tsv",
parameter_file="parameters.tsv",
observable_file="observables.tsv",
yaml_file="problem.yaml",
)
shutil.copy(self.model.jax_py_file, directory / "jax_py_file.py")
with open(directory / "parameters.pkl", "wb") as f:
eqx.tree_serialise_leaves(f, self)

@classmethod
def load(cls, directory: Path):
"""
Load a problem from a directory.
:param directory:
Directory to load the problem from.
:return:
Loaded problem instance.
"""
petab_problem = petab.Problem.from_yaml(
directory / "problem.yaml",
)
model = _module_from_path("jax", directory / "jax_py_file.py").Model()
problem = cls(model, petab_problem)
with open(directory / "parameters.pkl", "rb") as f:
return eqx.tree_deserialise_leaves(f, problem)

def _get_parameter_mappings(
self, simulation_conditions: pd.DataFrame
) -> dict[str, ParameterMappingForCondition]:
Expand Down
126 changes: 94 additions & 32 deletions python/sdist/amici/petab/parameter_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
PARAMETER_SCALE,
PREEQUILIBRATION_CONDITION_ID,
SIMULATION_CONDITION_ID,
NOMINAL_VALUE,
ESTIMATE,
)
from petab.v1.models import MODEL_TYPE_PYSB, MODEL_TYPE_SBML
from sympy.abc import _clash
Expand All @@ -60,9 +62,9 @@ class ParameterMappingForCondition:
"""Parameter mapping for condition.
Contains mappings for free parameters, fixed parameters, and fixed
preequilibration parameters, both for parameters and scales.
pre-equilibration parameters, both for parameters and scales.
In the scale mappings, for each simulation parameter the scale
In the scale mappings, for each simulation parameter, the scale
on which the value is passed (and potentially gradients are to be
returned) is given. In the parameter mappings, for each simulation
parameter a corresponding optimization parameter (or a numeric value)
Expand All @@ -76,9 +78,9 @@ class ParameterMappingForCondition:
:param scale_map_sim_var:
Scales for free simulation parameters.
:param map_preeq_fix:
Mapping for fixed preequilibration parameters.
Mapping for fixed pre-equilibration parameters.
:param scale_map_preeq_fix:
Scales for fixed preequilibration parameters.
Scales for fixed pre-equilibration parameters.
:param map_sim_fix:
Mapping for fixed simulation parameters.
:param scale_map_sim_fix:
Expand Down Expand Up @@ -177,7 +179,7 @@ def __len__(self):
def append(
self, parameter_mapping_for_condition: ParameterMappingForCondition
):
"""Append a condition specific parameter mapping."""
"""Append a condition-specific parameter mapping."""
self.parameter_mappings.append(parameter_mapping_for_condition)

def __repr__(self):
Expand Down Expand Up @@ -307,9 +309,10 @@ def unscale_parameters_dict(

def create_parameter_mapping(
petab_problem: petab.Problem,
simulation_conditions: pd.DataFrame | list[dict],
simulation_conditions: pd.DataFrame | list[dict] | None,
scaled_parameters: bool,
amici_model: AmiciModel | None = None,
fill_fixed_parameters: bool = True,
**parameter_mapping_kwargs,
) -> ParameterMapping:
"""Generate AMICI specific parameter mapping.
Expand All @@ -325,11 +328,14 @@ def create_parameter_mapping(
are assumed to be in linear scale.
:param amici_model:
AMICI model.
:param fill_fixed_parameters:
Whether to fill in nominal values for fixed parameters
(estimate=0 in the parameters table).
To allow changing fixed PEtab problem parameters,
use ``fill_fixed_parameters=False``.
:param parameter_mapping_kwargs:
Optional keyword arguments passed to
:func:`petab.get_optimization_to_simulation_parameter_mapping`.
To allow changing fixed PEtab problem parameters (``estimate=0``),
use ``fill_fixed_parameters=False``.
:return:
List of the parameter mappings.
"""
Expand Down Expand Up @@ -381,6 +387,7 @@ def create_parameter_mapping(
mapping_df=petab_problem.mapping_df,
model=petab_problem.model,
simulation_conditions=simulation_conditions,
fill_fixed_parameters=fill_fixed_parameters,
**dict(
default_parameter_mapping_kwargs, **parameter_mapping_kwargs
),
Expand All @@ -392,7 +399,11 @@ def create_parameter_mapping(
simulation_conditions.iterrows(), prelim_parameter_mapping, strict=True
):
mapping_for_condition = create_parameter_mapping_for_condition(
prelim_mapping_for_condition, condition, petab_problem, amici_model
prelim_mapping_for_condition,
condition,
petab_problem,
amici_model,
fill_fixed_parameters=fill_fixed_parameters,
)
parameter_mapping.append(mapping_for_condition)

Expand All @@ -404,8 +415,9 @@ def create_parameter_mapping_for_condition(
condition: pd.Series | dict,
petab_problem: petab.Problem,
amici_model: AmiciModel | None = None,
fill_fixed_parameters: bool = True,
) -> ParameterMappingForCondition:
"""Generate AMICI specific parameter mapping for condition.
"""Generate AMICI-specific parameter mapping for a PEtab simulation.
:param parameter_mapping_for_condition:
Preliminary parameter mapping for condition.
Expand All @@ -416,10 +428,12 @@ def create_parameter_mapping_for_condition(
Underlying PEtab problem.
:param amici_model:
AMICI model.
:param fill_fixed_parameters:
Whether to fill in nominal values for fixed parameters
(estimate=0 in the parameters table).
:return:
The parameter and parameter scale mappings, for fixed
preequilibration, fixed simulation, and variable simulation
pre-equilibration, fixed simulation, and variable simulation
parameters, and then the respective scalings.
"""
(
Expand All @@ -440,10 +454,10 @@ def create_parameter_mapping_for_condition(
if len(condition_map_preeq) and len(condition_map_preeq) != len(
condition_map_sim
):
logger.debug(f"Preequilibration parameter map: {condition_map_preeq}")
logger.debug(f"Pre-equilibration parameter map: {condition_map_preeq}")
logger.debug(f"Simulation parameter map: {condition_map_sim}")
raise AssertionError(
"Number of parameters for preequilbration "
"Number of parameters for pre-equilbration "
"and simulation do not match."
)

Expand All @@ -455,8 +469,8 @@ def create_parameter_mapping_for_condition(
# During model generation, parameters for initial concentrations and
# respective initial assignments have been created for the
# relevant species, here we add these parameters to the parameter mapping.
# In absence of preequilibration this could also be handled via
# ExpData.x0, but in the case of preequilibration this would not allow for
# In the absence of pre-equilibration this could also be handled via
# ExpData.x0, but in the case of pre-equilibration this would not allow for
# resetting initial states.

if states_in_condition_table := get_states_in_condition_table(
Expand Down Expand Up @@ -489,10 +503,11 @@ def create_parameter_mapping_for_condition(
condition_map_preeq,
condition_scale_map_preeq,
preeq_value,
fill_fixed_parameters=fill_fixed_parameters,
)
# need to set dummy value for preeq parameter anyways, as it
# need to set a dummy value for preeq parameter anyways, as it
# is expected below (set to 0, not nan, because will be
# multiplied with indicator variable in initial assignment)
# multiplied with the indicator variable in initial assignment)
condition_map_sim[init_par_id] = 0.0
condition_scale_map_sim[init_par_id] = LIN

Expand All @@ -507,6 +522,7 @@ def create_parameter_mapping_for_condition(
condition_map_sim,
condition_scale_map_sim,
value,
fill_fixed_parameters=fill_fixed_parameters,
)
# set dummy value as above
if condition_map_preeq:
Expand Down Expand Up @@ -553,11 +569,11 @@ def create_parameter_mapping_for_condition(
condition_scale_map_sim_fix = {}

logger.debug(
"Fixed parameters preequilibration: " f"{condition_map_preeq_fix}"
"Fixed parameters pre-equilibration: " f"{condition_map_preeq_fix}"
)
logger.debug("Fixed parameters simulation: " f"{condition_map_sim_fix}")
logger.debug(
"Variable parameters preequilibration: " f"{condition_map_preeq_var}"
"Variable parameters pre-equilibration: " f"{condition_map_preeq_var}"
)
logger.debug("Variable parameters simulation: " f"{condition_map_sim_var}")

Expand All @@ -583,21 +599,46 @@ def create_parameter_mapping_for_condition(


def _set_initial_state(
petab_problem,
condition_id,
element_id,
init_par_id,
par_map,
scale_map,
value,
):
petab_problem: petab.Problem,
condition_id: str,
element_id: str,
init_par_id: str,
par_map: petab.ParMappingDict,
scale_map: petab.ScaleMappingDict,
value: str | float,
fill_fixed_parameters: bool = True,
) -> None:
"""
Update the initial value for a model entity in the parameter mapping
according to the PEtab conditions table.
:param petab_problem: The PEtab problem
:param condition_id: The current condition ID
:param element_id: Element for which to set the initial value
:param init_par_id: The parameter ID that refers to the initial value
:param par_map: Parameter value mapping
:param scale_map: Parameter scale mapping
:param value: The initial value for `element_id` in `condition_id`
:param fill_fixed_parameters:
Whether to fill in nominal values for fixed parameters
(estimate=0 in the parameters table).
"""
value = petab.to_float_if_float(value)
# NaN indicates that the initial value is to be taken from the model
# (if this is the pre-equilibration condition, or the simulation condition
# when no pre-equilibration condition is set) or is not to be reset
# (if this is the simulation condition following pre-equilibration)-
# The latter is not handled here.
if pd.isna(value):
if petab_problem.model.type_id == MODEL_TYPE_SBML:
value = _get_initial_state_sbml(petab_problem, element_id)
elif petab_problem.model.type_id == MODEL_TYPE_PYSB:
value = _get_initial_state_pysb(petab_problem, element_id)

else:
raise NotImplementedError(
f"Model type {petab_problem.model.type_id} not supported."
)
# the initial value can be a numeric value or a sympy expression
try:
value = float(value)
except (ValueError, TypeError):
Expand All @@ -618,14 +659,24 @@ def _set_initial_state(
f"defined for the condition {condition_id} in "
"the PEtab conditions table. The initial value is "
f"now set to {value}, which is the initial value "
"defined in the SBML model."
"defined in the original model."
)

par_map[init_par_id] = value
if isinstance(value, float):
# numeric initial state
scale_map[init_par_id] = petab.LIN
else:
# parametric initial state
if (
fill_fixed_parameters
and petab_problem.parameter_df is not None
and value in petab_problem.parameter_df.index
and petab_problem.parameter_df.loc[value, ESTIMATE] == 0
):
par_map[init_par_id] = petab_problem.parameter_df.loc[
value, NOMINAL_VALUE
]
scale_map[init_par_id] = petab_problem.parameter_df[
PARAMETER_SCALE
].get(value, petab.LIN)
Expand All @@ -642,7 +693,7 @@ def _subset_dict(
Collections of keys to be contained in the different subsets
:return:
subsetted dictionary
Subsetted dictionary
"""
for keys in args:
yield {key: val for (key, val) in full.items() if key in keys}
Expand All @@ -651,6 +702,11 @@ def _subset_dict(
def _get_initial_state_sbml(
petab_problem: petab.Problem, element_id: str
) -> float | sp.Basic:
"""Get the initial value of an SBML model entity.
Get the initial value of an SBML model entity (species, parameter, ...) as
defined in the model (not considering any condition table overrides).
"""
import libsbml

element = petab_problem.sbml_model.getElementBySId(element_id)
Expand Down Expand Up @@ -692,9 +748,15 @@ def _get_initial_state_sbml(
def _get_initial_state_pysb(
petab_problem: petab.Problem, element_id: str
) -> float | sp.Symbol:
"""Get the initial value of a PySB model entity.
Get the initial value of an PySB model entity as defined in the model
(not considering any condition table overrides).
"""
from pysb.pattern import match_complex_pattern

species_idx = int(re.match(r"__s(\d+)$", element_id)[1])
species_pattern = petab_problem.model.model.species[species_idx]
from pysb.pattern import match_complex_pattern

value = next(
(
Expand Down
Loading

0 comments on commit d166f03

Please sign in to comment.