Skip to content

Commit

Permalink
optional
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Dec 11, 2024
1 parent 2fd6da8 commit e716487
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
22 changes: 17 additions & 5 deletions petab/v1/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@

__all__ = ["priors_to_measurements"]

# TODO: does anybody really rely on the old behavior?
USE_PROPER_TRUNCATION = True


class Prior:
"""A PEtab parameter prior.
Expand All @@ -61,6 +58,15 @@ class Prior:
on the `parameter_scale` scale).
:param bounds: The untransformed bounds of the sample (lower, upper).
:param transformation: The transformation of the distribution.
:param bounds_truncate: Whether the generated prior will be truncated
at the bounds.
If ``True``, the probability density will be rescaled
accordingly and the sample is generated from the truncated
distribution.
If ``False``, the probability density will not account for the
bounds, but any parameter samples outside the bounds will be set to
the value of the closest bound. In this case, the PDF might not match
the sample.
"""

def __init__(
Expand All @@ -69,6 +75,7 @@ def __init__(
parameters: tuple,
bounds: tuple = None,
transformation: str = C.LIN,
bounds_truncate: bool = True,
):
if transformation not in C.PARAMETER_SCALES:
raise ValueError(
Expand All @@ -91,7 +98,7 @@ def __init__(
self._bounds = bounds
self._transformation = transformation

truncation = bounds if USE_PROPER_TRUNCATION else None
truncation = bounds if bounds_truncate else None
if truncation is not None:
# for uniform, we don't want to implement truncation and just
# adapt the distribution parameters
Expand Down Expand Up @@ -235,12 +242,16 @@ def neglogprior(self, x):

@staticmethod
def from_par_dict(
d, type_=Literal["initialization", "objective"]
d,
type_=Literal["initialization", "objective"],
bounds_truncate: bool = True,
) -> Prior:
"""Create a distribution from a row of the parameter table.
:param d: A dictionary representing a row of the parameter table.
:param type_: The type of the distribution.
:param bounds_truncate: Whether the generated prior will be truncated
at the bounds.
:return: A distribution object.
"""
dist_type = d.get(f"{type_}PriorType", C.PARAMETER_SCALE_UNIFORM)
Expand Down Expand Up @@ -268,6 +279,7 @@ def from_par_dict(
parameters=params,
bounds=(d[C.LOWER_BOUND], d[C.UPPER_BOUND]),
transformation=pscale,
bounds_truncate=bounds_truncate,
)


Expand Down
4 changes: 3 additions & 1 deletion tests/v1/test_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ def apply_parameter_values(row):
]
priors = [
Prior.from_par_dict(
petab_problem_priors.parameter_df.loc[par_id], type_="objective"
petab_problem_priors.parameter_df.loc[par_id],
type_="objective",
bounds_truncate=False,
)
for par_id in parameter_ids
]
Expand Down

0 comments on commit e716487

Please sign in to comment.