Skip to content

Commit

Permalink
Add SwigPtrView fields to dir() (#2244)
Browse files Browse the repository at this point in the history
This way, `dir(SwigPtrView(...))` will show the available fields and sufficiently smart IDEs will show them for code completion.
  • Loading branch information
dweindl authored Jan 2, 2024
1 parent 409ae3d commit da93984
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions python/sdist/amici/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import collections
import copy
import itertools
from typing import Dict, Iterator, List, Literal, Union

import amici
Expand Down Expand Up @@ -164,6 +165,13 @@ def __eq__(self, other):
return False
return self._swigptr == other._swigptr

def __dir__(self):
return sorted(
set(
itertools.chain(dir(super()), self.__dict__, self._field_names)
)
)


class ReturnDataView(SwigPtrView):
"""
Expand Down Expand Up @@ -237,7 +245,7 @@ def __init__(self, rdata: Union[ReturnDataPtr, ReturnData]):
if not isinstance(rdata, (ReturnDataPtr, ReturnData)):
raise TypeError(
f"Unsupported pointer {type(rdata)}, must be"
f"amici.ExpDataPtr!"
f"amici.ReturnDataPtr or amici.ReturnData!"
)
self._field_dimensions = {
"ts": [rdata.nt],
Expand Down Expand Up @@ -288,7 +296,7 @@ def __init__(self, rdata: Union[ReturnDataPtr, ReturnData]):
"numerrtestfailsB": [rdata.nt],
"numnonlinsolvconvfailsB": [rdata.nt],
}
super(ReturnDataView, self).__init__(rdata)
super().__init__(rdata)

def __getitem__(
self, item: str
Expand Down Expand Up @@ -406,7 +414,7 @@ def __init__(self, edata: Union[ExpDataPtr, ExpData]):
edata.observedDataStdDev = edata.getObservedDataStdDev()
edata.observedEvents = edata.getObservedEvents()
edata.observedEventsStdDev = edata.getObservedEventsStdDev()
super(ExpDataView, self).__init__(edata)
super().__init__(edata)


def _field_as_numpy(
Expand Down

0 comments on commit da93984

Please sign in to comment.