diff --git a/python/sdist/amici/swig_wrappers.py b/python/sdist/amici/swig_wrappers.py index 4e4ca3a5b1..05527bfb46 100644 --- a/python/sdist/amici/swig_wrappers.py +++ b/python/sdist/amici/swig_wrappers.py @@ -8,7 +8,7 @@ import amici import amici.amici as amici_swig - +from amici.amici import _get_ptr from . import numpy from .logging import get_logger @@ -53,36 +53,6 @@ def _capture_cstdout(): yield -def _get_ptr( - obj: Union[AmiciModel, AmiciExpData, AmiciSolver, AmiciReturnData], -) -> Union[ - "amici_swig.Model", - "amici_swig.ExpData", - "amici_swig.Solver", - "amici_swig.ReturnData", -]: - """ - Convenience wrapper that returns the smart pointer pointee, if applicable - - :param obj: - Potential smart pointer - - :returns: - Non-smart pointer - """ - if isinstance( - obj, - ( - amici_swig.ModelPtr, - amici_swig.ExpDataPtr, - amici_swig.SolverPtr, - amici_swig.ReturnDataPtr, - ), - ): - return obj.get() - return obj - - def runAmiciSimulation( model: AmiciModel, solver: AmiciSolver, diff --git a/swig/amici.i b/swig/amici.i index 46a58f8365..645cd043b5 100644 --- a/swig/amici.i +++ b/swig/amici.i @@ -353,6 +353,8 @@ if sys.platform == 'win32' and (dll_dirs := os.environ.get('AMICI_DLL_DIRS')): // import additional types for typehints // also import np for use in __repr__ functions %pythonbegin %{ +from __future__ import annotations + from typing import TYPE_CHECKING, Iterable, Sequence import numpy as np if TYPE_CHECKING: @@ -368,4 +370,35 @@ __all__ = [ if not x.startswith('_') and x not in {"np", "sys", "os", "numpy", "IntEnum", "enum", "pi", "TYPE_CHECKING", "Iterable", "Sequence"} ] + + +def _get_ptr( + obj: Union[AmiciModel, AmiciExpData, AmiciSolver, AmiciReturnData], +) -> Union[ + Model, + ExpData, + Solver, + ReturnData, +]: + """ + Convenience wrapper that returns the smart pointer pointee, if applicable + + :param obj: + Potential smart pointer + + :returns: + Non-smart pointer + """ + if isinstance( + obj, + ( + ModelPtr, + ExpDataPtr, + SolverPtr, + ReturnDataPtr, + ), + ): + return obj.get() + return obj + %}