diff --git a/python/sdist/amici/numpy.py b/python/sdist/amici/numpy.py index b84e52cc2b..93b04603be 100644 --- a/python/sdist/amici/numpy.py +++ b/python/sdist/amici/numpy.py @@ -6,6 +6,7 @@ import collections import copy +import itertools from typing import Dict, Iterator, List, Literal, Union import amici @@ -164,6 +165,13 @@ def __eq__(self, other): return False return self._swigptr == other._swigptr + def __dir__(self): + return sorted( + set( + itertools.chain(dir(super()), self.__dict__, self._field_names) + ) + ) + class ReturnDataView(SwigPtrView): """ @@ -237,7 +245,7 @@ def __init__(self, rdata: Union[ReturnDataPtr, ReturnData]): if not isinstance(rdata, (ReturnDataPtr, ReturnData)): raise TypeError( f"Unsupported pointer {type(rdata)}, must be" - f"amici.ExpDataPtr!" + f"amici.ReturnDataPtr or amici.ReturnData!" ) self._field_dimensions = { "ts": [rdata.nt], @@ -288,7 +296,7 @@ def __init__(self, rdata: Union[ReturnDataPtr, ReturnData]): "numerrtestfailsB": [rdata.nt], "numnonlinsolvconvfailsB": [rdata.nt], } - super(ReturnDataView, self).__init__(rdata) + super().__init__(rdata) def __getitem__( self, item: str @@ -406,7 +414,7 @@ def __init__(self, edata: Union[ExpDataPtr, ExpData]): edata.observedDataStdDev = edata.getObservedDataStdDev() edata.observedEvents = edata.getObservedEvents() edata.observedEventsStdDev = edata.getObservedEventsStdDev() - super(ExpDataView, self).__init__(edata) + super().__init__(edata) def _field_as_numpy(