Skip to content

Commit

Permalink
Migrate modelfun from boilerdata
Browse files Browse the repository at this point in the history
  • Loading branch information
blakeNaccarato committed Sep 7, 2023
1 parent 700afa1 commit b45a849
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 0 deletions.
1 change: 1 addition & 0 deletions .tools/requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Latest working requirements of your package, should be ahead of pyproject.toml.
dill==0.3.7
IPython[notebook]==8.14.0
pandas[hdf5,performance]==2.0.2
pandas-stubs~=2.0.2
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ license = { file = "LICENSE" }
requires-python = ">=3.11"
classifiers = ["License :: OSI Approved :: MIT License"]
dependencies = [
"dill>=0.3.7",
"IPython[notebook]>=8.14.0",
"pandas[hdf5,performance]>=2.0.2",
"ploomber-engine>=0.0.30",
Expand Down
38 changes: 38 additions & 0 deletions src/boilercore/modelfun.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import warnings
from collections.abc import Callable
from functools import wraps
from pathlib import Path
from typing import Any

import dill
import numpy as np


def get_model(model: Path):
"""Unpickle the model function for fitting data."""
file_bytes = Path(model).read_bytes()
with warnings.catch_warnings():
warnings.simplefilter("ignore", dill.UnpicklingWarning)
unpickled_model = dill.loads(file_bytes)
return unpickled_model.basic, fix_model(unpickled_model.for_ufloat)


def fix_model(f) -> Callable[..., Any]:
"""Fix edge-cases of lambdify where all inputs must be arrays.
See the notes section in the link below where it says, "However, in some cases
the generated function relies on the input being a numpy array."
https://docs.sympy.org/latest/modules/utilities/lambdify.html#sympy.utilities.lambdify.lambdify
"""

@wraps(f)
def wrapper(*args, **kwargs):
result = f(
*(np.array(arg) for arg in args),
**{k: np.array(v) for k, v in kwargs.items()},
)

return result if result.size > 1 else result.item()

return wrapper

0 comments on commit b45a849

Please sign in to comment.