Skip to content

Commit

Permalink
Merge branch 'develop' into fix_install_mac
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich authored Mar 6, 2024
2 parents f0213ea + b8be5a0 commit 1ac844d
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion python/sdist/amici/swig.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class TypeHintFixer(ast.NodeTransformer):
"ptrdiff_t": ast.Name("int"),
"size_t": ast.Name("int"),
"bool": ast.Name("bool"),
"boolean": ast.Name("bool"),
"std::unique_ptr< amici::Solver >": ast.Constant("Solver"),
"amici::InternalSensitivityMethod": ast.Constant(
"InternalSensitivityMethod"
Expand All @@ -40,8 +41,10 @@ class TypeHintFixer(ast.NodeTransformer):
"SteadyStateSensitivityMode"
),
"amici::realtype": ast.Name("float"),
"DoubleVector": ast.Constant("Sequence[float]"),
"DoubleVector": ast.Name("Sequence[float]"),
"BoolVector": ast.Name("Sequence[bool]"),
"IntVector": ast.Name("Sequence[int]"),
"StringVector": ast.Name("Sequence[str]"),
"std::string": ast.Name("str"),
"std::string const &": ast.Name("str"),
"std::unique_ptr< amici::ExpData >": ast.Constant("ExpData"),
Expand All @@ -53,6 +56,8 @@ class TypeHintFixer(ast.NodeTransformer):
}

def visit_FunctionDef(self, node):
self._annotation_from_docstring(node)

# Has a return type annotation?
if node.returns:
node.returns = self._new_annot(node.returns.value)
Expand Down Expand Up @@ -103,6 +108,42 @@ def _new_annot(self, old_annot: str):

return ast.Constant(old_annot)

def _annotation_from_docstring(self, node: ast.FunctionDef):
"""Add annotations based on docstring.
If any argument or return type of the function is not annotated, but
the corresponding docstring contains a type hint, the type hint is used
as the annotation.
"""
docstring = ast.get_docstring(node, clean=False)
if not docstring or "*Overload 1:*" in docstring:
# skip overloaded methods
return

docstring = docstring.split("\n")
lines_to_remove = set()

for line_no, line in enumerate(docstring):
if match := re.match(r"\W*:rtype:\W*(.+)", line):
node.returns = ast.Constant(match.group(1))
lines_to_remove.add(line_no)

if match := re.match(r"\W*:type:\W*(\w+):\W*(.+)", line):
for arg in node.args.args:
if arg.arg == match.group(1):
arg.annotation = ast.Constant(match.group(2))
lines_to_remove.add(line_no)

if lines_to_remove:
# Update docstring with type annotations removed
assert isinstance(node.body[0].value, ast.Constant)
new_docstring = "\n".join(
line
for line_no, line in enumerate(docstring)
if line_no not in lines_to_remove
)
node.body[0].value = ast.Str(new_docstring)


def fix_typehints(infilename, outfilename):
"""Change SWIG-generated C++ typehints to Python typehints"""
Expand Down

0 comments on commit 1ac844d

Please sign in to comment.