From 6f12f85ae59657790dc9b19faabe7d79f95b333e Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Thu, 7 Mar 2024 21:18:31 +0100 Subject: [PATCH] Fix type annotations in swig-wrappers (again) (#2365) Some things were missing in https://github.com/AMICI-dev/AMICI/pull/2344 --- python/sdist/amici/swig.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/python/sdist/amici/swig.py b/python/sdist/amici/swig.py index 81a030ba3d..fbc486c301 100644 --- a/python/sdist/amici/swig.py +++ b/python/sdist/amici/swig.py @@ -59,7 +59,7 @@ def visit_FunctionDef(self, node): self._annotation_from_docstring(node) # Has a return type annotation? - if node.returns: + if node.returns and isinstance(node.returns, ast.Constant): node.returns = self._new_annot(node.returns.value) # Has arguments? @@ -112,8 +112,11 @@ def _annotation_from_docstring(self, node: ast.FunctionDef): """Add annotations based on docstring. If any argument or return type of the function is not annotated, but - the corresponding docstring contains a type hint, the type hint is used - as the annotation. + the corresponding docstring contains a type hint (``:rtype:`` or + ``:type:``), the type hint is used as the annotation. + + Swig sometimes generates ``:type solver: :py:class:`Solver`` instead of + ``:type solver: Solver``. Those need special treatment. """ docstring = ast.get_docstring(node, clean=False) if not docstring or "*Overload 1:*" in docstring: @@ -124,11 +127,19 @@ def _annotation_from_docstring(self, node: ast.FunctionDef): lines_to_remove = set() for line_no, line in enumerate(docstring): - if match := re.match(r"\W*:rtype:\W*(.+)", line): + if ( + match := re.match( + r"\s*:rtype:\s*(?::py:class:`)?(\w+)`?\s+$", line + ) + ) and not match.group(1).startswith(":"): node.returns = ast.Constant(match.group(1)) lines_to_remove.add(line_no) - if match := re.match(r"\W*:type:\W*(\w+):\W*(.+)", line): + if ( + match := re.match( + r"\s*:type\s*(\w+):\W*(?::py:class:`)?(\w+)`?\s+$", line + ) + ) and not match.group(1).startswith(":"): for arg in node.args.args: if arg.arg == match.group(1): arg.annotation = ast.Constant(match.group(2))