-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
700afa1
commit b45a849
Showing
3 changed files
with
40 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import warnings | ||
from collections.abc import Callable | ||
from functools import wraps | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
import dill | ||
import numpy as np | ||
|
||
|
||
def get_model(model: Path): | ||
"""Unpickle the model function for fitting data.""" | ||
file_bytes = Path(model).read_bytes() | ||
with warnings.catch_warnings(): | ||
warnings.simplefilter("ignore", dill.UnpicklingWarning) | ||
unpickled_model = dill.loads(file_bytes) | ||
return unpickled_model.basic, fix_model(unpickled_model.for_ufloat) | ||
|
||
|
||
def fix_model(f) -> Callable[..., Any]: | ||
"""Fix edge-cases of lambdify where all inputs must be arrays. | ||
See the notes section in the link below where it says, "However, in some cases | ||
the generated function relies on the input being a numpy array." | ||
https://docs.sympy.org/latest/modules/utilities/lambdify.html#sympy.utilities.lambdify.lambdify | ||
""" | ||
|
||
@wraps(f) | ||
def wrapper(*args, **kwargs): | ||
result = f( | ||
*(np.array(arg) for arg in args), | ||
**{k: np.array(v) for k, v in kwargs.items()}, | ||
) | ||
|
||
return result if result.size > 1 else result.item() | ||
|
||
return wrapper |