some more passing UDE tests
FFroehlich committed Dec 2, 2024
1 parent 2672be2 commit 76d599c
Showing 5 changed files with 245 additions and 36 deletions.
64 changes: 55 additions & 9 deletions python/sdist/amici/jax/nn.py
@@ -26,6 +26,10 @@ def __call__(self, x):
)


def tanhshrink(x: jnp.ndarray) -> jnp.ndarray:
return x - jnp.tanh(x)


def generate_equinox(ml_model: MLModel, filename: Path | str):
filename = Path(filename)
layer_indent = 12
@@ -55,6 +59,14 @@ def generate_equinox(ml_model: MLModel, filename: Path | str):
]
)[node_indent:],
"INPUT": ", ".join([f"'{inp.input_id}'" for inp in ml_model.inputs]),
"OUTPUT": ", ".join(
[
f"'{arg}'"
for arg in next(
node for node in ml_model.forward if node.op == "output"
).args
]
),
"N_LAYERS": len(ml_model.layers),
}

@@ -82,8 +94,19 @@ def _generate_layer(layer: Layer, indent: int, ilayer: int) -> str:
"InstanceNorm3d": "eqx.nn.LayerNorm",
"Dropout1d": "eqx.nn.Dropout",
"Dropout2d": "eqx.nn.Dropout",
"Flatten": "Flatten",
"Flatten": "amici.jax.nn.Flatten",
}
if layer.layer_type.startswith(("BatchNorm", "AlphaDropout")):
raise NotImplementedError(
f"{layer.layer_type} layers currently not supported"
)
if layer.layer_type.startswith("MaxPool") and "dilation" in layer.args:
raise NotImplementedError("MaxPool layers with dilation not supported")
if layer.layer_type.startswith("Dropout") and "inplace" in layer.args:
raise NotImplementedError("Dropout layers with inplace not supported")
if layer.layer_type == "Bilinear":
raise NotImplementedError("Bilinear layers not supported")

kwarg_map = {
"Linear": {
"bias": "use_bias",
@@ -106,11 +129,18 @@ def _generate_layer(layer: Layer, indent: int, ilayer: int) -> str:
"affine": "elementwise_affine",
"num_features": "shape",
},
"LayerNorm": {
"affine": "elementwise_affine",
"normalized_shape": "shape",
},
}
kwarg_ignore = {
"InstanceNorm1d": ("track_running_stats", "momentum"),
"InstanceNorm2d": ("track_running_stats", "momentum"),
"InstanceNorm3d": ("track_running_stats", "momentum"),
"BatchNorm1d": ("track_running_stats", "momentum"),
"BatchNorm2d": ("track_running_stats", "momentum"),
"BatchNorm3d": ("track_running_stats", "momentum"),
"Dropout1d": ("inplace",),
"Dropout2d": ("inplace",),
}
@@ -120,7 +150,15 @@ def _generate_layer(layer: Layer, indent: int, ilayer: int) -> str:
if k not in kwarg_ignore.get(layer.layer_type, ())
]
# add key for initialization
if layer.layer_type in ("Linear", "Conv1d", "Conv2d", "Conv3d"):
if layer.layer_type in (
"Linear",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
):
kwargs += [f"key=keys[{ilayer}]"]
type_str = layer_map.get(layer.layer_type, f"eqx.nn.{layer.layer_type}")
layer_str = f"{type_str}({', '.join(kwargs)})"
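
# --- editorial sketch, not part of the commit ---------------------------------
# Toy illustration of the translation implemented above: torch-style argument
# names are renamed via kwarg_map where the equinox layer uses a different
# name, names listed in kwarg_ignore are dropped, and an initialisation key is
# appended for layers with trainable weights. Variable names are paraphrased
# from _generate_layer; the snippet is plain Python and runnable on its own.
layer_type = "Linear"
args = {"in_features": 2, "out_features": 1, "bias": False}
kwarg_map = {"Linear": {"bias": "use_bias"}}
kwarg_ignore = {"Dropout1d": ("inplace",)}
kwargs = [
    f"{kwarg_map.get(layer_type, {}).get(k, k)}={v}"
    for k, v in args.items()
    if k not in kwarg_ignore.get(layer_type, ())
]
kwargs += ["key=keys[0]"]  # only for Linear/Conv*/ConvTranspose* layers
print(f"eqx.nn.{layer_type}({', '.join(kwargs)})")
# -> eqx.nn.Linear(in_features=2, out_features=1, use_bias=False, key=keys[0])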
@@ -141,20 +179,28 @@ def _generate_forward(node: Node, indent, layer_type=str) -> str:

if node.op == "call_module":
fun_str = f"self.layers['{node.target}']"
if layer_type.startswith(("InstanceNorm", "Conv", "Linear")):
if layer_type.startswith(
("InstanceNorm", "Conv", "Linear", "LayerNorm")
):
if layer_type in ("LayerNorm", "InstanceNorm"):
dims = f"len({fun_str}.shape)+1"
if layer_type == "Linear":
dims = 1
if layer_type.endswith(("1d",)):
dims = 2
elif layer_type.endswith(("2d",)):
if layer_type.endswith(("1d",)):
dims = 3
elif layer_type.endswith("3d"):
elif layer_type.endswith(("2d",)):
dims = 4
fun_str = f"(jax.vmap({fun_str}) if len({node.args[0]}.shape) == {dims + 1} else {fun_str})"
elif layer_type.endswith("3d"):
dims = 5
fun_str = f"(jax.vmap({fun_str}) if len({node.args[0]}.shape) == {dims} else {fun_str})"

if node.op in ("call_function", "call_method"):
map_fun = {
"hardtanh": "jax.nn.hard_tanh",
"hardsigmoid": "jax.nn.hard_sigmoid",
"hardswish": "jax.nn.hard_swish",
"tanhshrink": "amici.jax.nn.tanhshrink",
"softsign": "jax.nn.soft_sign",
}
if node.target == "hardtanh":
if node.kwargs.pop("min_val", -1.0) != -1.0:
@@ -172,7 +218,7 @@ def _generate_forward(node: Node, indent, layer_type=str) -> str:
f"{k}={v}" for k, v in node.kwargs.items() if k not in ("inplace",)
]
if layer_type.startswith(("Dropout",)):
kwargs += ["inference=inference", "key=key"]
kwargs += ["key=key"]
kwargs_str = ", ".join(kwargs)
if node.op in ("call_module", "call_function", "call_method"):
return f"{' ' * indent}{node.name} = {fun_str}({args + ', ' + kwargs_str})"
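A minimal, runnable sketch of the two nn.py additions above (illustrative only, not part of the diff): the tanhshrink activation, and the batched-vs-unbatched dispatch that the generated forward code emits for Conv2d layers, where the equinox module operates on a single sample and is wrapped in jax.vmap when the input carries a leading batch axis.

import equinox as eqx
import jax
import jax.numpy as jnp


def tanhshrink(x: jnp.ndarray) -> jnp.ndarray:
    # same definition as the new amici.jax.nn.tanhshrink
    return x - jnp.tanh(x)


conv = eqx.nn.Conv2d(
    in_channels=3, out_channels=4, kernel_size=3, key=jax.random.PRNGKey(0)
)

x_single = jnp.ones((3, 8, 8))      # (C, H, W): applied directly
x_batched = jnp.ones((5, 3, 8, 8))  # (B, C, H, W): mapped over the batch axis

apply = lambda x: (jax.vmap(conv) if len(x.shape) == 4 else conv)(x)
print(apply(x_single).shape)   # (4, 6, 6)
print(apply(x_batched).shape)  # (5, 4, 6, 6)
print(tanhshrink(jnp.zeros(2)))  # [0. 0.]
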
6 changes: 4 additions & 2 deletions python/sdist/amici/jax/nn.template.py
@@ -3,20 +3,22 @@
import jax.nn
import jax.random as jr
import jax
from amici.jax.nn import Flatten
import amici.jax.nn


class TPL_MODEL_ID(eqx.Module):
layers: dict
inputs: list[str]
outputs: list[str]

def __init__(self, key):
super().__init__()
keys = jr.split(key, TPL_N_LAYERS)
self.layers = {TPL_LAYERS}
self.inputs = [TPL_INPUT]
self.outputs = [TPL_OUTPUT]

def forward(self, input, inference=False, key=None):
def forward(self, input, key=None):
TPL_FORWARD
return output

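For orientation, a hypothetical instance of this template after generate_equinox has substituted the TPL_* placeholders (the class name, layer, and input/output ids below are invented for illustration): layers live in a dict keyed by layer id, the PEtab input/output ids are recorded as lists, and forward now only threads an optional PRNG key, which dropout layers consume, instead of the removed inference flag.

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr


class net1(eqx.Module):
    layers: dict
    inputs: list[str]
    outputs: list[str]

    def __init__(self, key):
        super().__init__()
        keys = jr.split(key, 1)
        self.layers = {"fc1": eqx.nn.Linear(2, 1, key=keys[0])}
        self.inputs = ["input1", "input2"]
        self.outputs = ["output1"]

    def forward(self, input, key=None):
        output = self.layers["fc1"](input)
        return output


net = net1(jr.PRNGKey(0))
print(net.forward(jnp.ones(2)).shape)  # (1,)
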
40 changes: 35 additions & 5 deletions python/sdist/amici/jax/petab.py
@@ -329,11 +329,34 @@ def _unscale(
def _eval_nn(self, output_par: str):
net_id = self._petab_problem.mapping_df.loc[output_par, "netId"]
nn = self.model.nns[net_id]
net_input = tuple(
jax.lax.stop_gradient(self._inputs[net_id][input_id])
for input_id in nn.inputs

model_id_map = (
self._petab_problem.mapping_df.query(f'netId == "{net_id}"')
.reset_index()
.set_index(petab.MODEL_ENTITY_ID)[petab.PETAB_ENTITY_ID]
.to_dict()
)
return nn.forward(*net_input).squeeze()

for petab_id in model_id_map.values():
if petab_id in self.model.state_ids:
raise NotImplementedError(
"State variables as inputs to neural networks are not supported"
)

net_input = jnp.array(
[
jax.lax.stop_gradient(self._inputs[net_id][model_id])
if model_id in self._inputs[net_id]
else self.get_petab_parameter_by_id(petab_id)
if petab_id in self.parameter_ids
else self._petab_problem.parameter_df.loc[
petab_id, petab.NOMINAL_VALUE
]
for model_id, petab_id in model_id_map.items()
if model_id.startswith("input")
]
)
return nn.forward(net_input).squeeze()

def load_parameters(
self, simulation_condition: str
@@ -347,10 +370,17 @@ def load_parameters(
Parameters for the simulation condition.
"""
mapping = self._parameter_mappings[simulation_condition]

nn_output_pars = self._petab_problem.mapping_df[
self._petab_problem.mapping_df[
petab.MODEL_ENTITY_ID
].str.startswith("output")
].index

p = jnp.array(
[
self._eval_nn(pname)
if pname in self._petab_problem.mapping_df.index
if pname in nn_output_pars
else pval
if isinstance(pval := mapping.map_sim_var[pname], Number)
else self.get_petab_parameter_by_id(pval)
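A toy illustration of the new input resolution in _eval_nn (synthetic data; petab.MODEL_ENTITY_ID and petab.PETAB_ENTITY_ID are assumed here to be the 'modelEntityId' and 'petabEntityId' columns of the mapping table): the mapping rows for one network are reshaped into a model-id -> PEtab-id dictionary, and the entries whose model id starts with 'input' form the network input vector, each resolved from preset inputs, estimated parameters, or nominal values.

import pandas as pd

mapping_df = pd.DataFrame(
    {
        "petabEntityId": ["p_in1", "p_in2", "p_out1"],
        "netId": ["net1", "net1", "net1"],
        "modelEntityId": ["input1", "input2", "output1"],
    }
).set_index("petabEntityId")

net_id = "net1"
model_id_map = (
    mapping_df.query(f'netId == "{net_id}"')
    .reset_index()
    .set_index("modelEntityId")["petabEntityId"]
    .to_dict()
)
print(model_id_map)
# {'input1': 'p_in1', 'input2': 'p_in2', 'output1': 'p_out1'}
print([m for m in model_id_map if m.startswith("input")])
# ['input1', 'input2']
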
8 changes: 7 additions & 1 deletion python/sdist/amici/petab/util.py
@@ -28,7 +28,13 @@ def get_states_in_condition_table(

species_check_funs = {
MODEL_TYPE_SBML: lambda x: _element_is_sbml_state(
petab_problem.sbml_model, x
petab_problem.sbml_model,
x, # v1
)
if isinstance(petab_problem, petab.Problem)
else lambda x: _element_is_sbml_state(
petab_problem.model.sbml_model,
x, # v2
),
MODEL_TYPE_PYSB: lambda x: _element_is_pysb_pattern(
petab_problem.model.model, x
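Finally, a simplified sketch of the v1-vs-v2 dispatch this hunk is after (the helper name is invented; only the attribute access mirrors the diff): for a PEtab v1 petab.Problem the SBML model is available directly as sbml_model, while the v2-style problem wraps it under model.sbml_model.

import petab


def _resolve_sbml_model(petab_problem):
    if isinstance(petab_problem, petab.Problem):  # PEtab v1 problem
        return petab_problem.sbml_model
    return petab_problem.model.sbml_model  # v2-style wrapper
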
