Skip to content

Commit

Permalink
Merge pull request #4 from orsinium-labs/types-from-names
Browse files Browse the repository at this point in the history
Guess return type from function name
  • Loading branch information
orsinium authored Feb 10, 2023
2 parents 3a07242 + 5e1ec01 commit 77a6666
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 21 deletions.
2 changes: 1 addition & 1 deletion infer_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@


__all__ = ['entrypoint', 'main']
__version__ = '0.3.1'
__version__ = '0.3.2'
22 changes: 15 additions & 7 deletions infer_types/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from typing import NoReturn, TextIO

from ._extractors import extractors
from ._format import format_code
from ._inferno import Inferno

Expand All @@ -27,6 +28,7 @@ class Config:
functions: bool # allow annotating functions
assumptions: bool # allow astypes to make assumptions
dry: bool # do not write changes in files
only: frozenset[str] # run only the these extractors
stream: TextIO # stdout


Expand All @@ -37,6 +39,7 @@ def add_annotations(root: Path, config: Config) -> None:
methods=config.methods,
functions=config.functions,
assumptions=config.assumptions,
only=config.only,
)
for path in root.iterdir():
if path.is_dir():
Expand All @@ -60,7 +63,11 @@ def main(argv: list[str], stream: TextIO) -> int:
parser = ArgumentParser()
parser.add_argument(
'dir', type=Path, default=Path(),
help='path to directory with the source code to analyze'
help='path to the directory with the source code to analyze',
)
parser.add_argument(
'--only', nargs='*', choices=sorted(name for name, _ in extractors),
help='list of extractors to run (all by default)',
)
parser.add_argument(
'--format', action='store_true',
Expand Down Expand Up @@ -104,15 +111,16 @@ def main(argv: list[str], stream: TextIO) -> int:
)
args = parser.parse_args(argv)
config = Config(
format=args.format,
skip_tests=args.skip_tests,
skip_migrations=args.skip_migrations,
assumptions=not args.no_assumptions,
dry=args.dry,
exit_on_failure=args.exit_on_failure or args.pdb,
format=args.format,
functions=not args.no_functions,
imports=not args.no_imports,
methods=not args.no_methods,
functions=not args.no_functions,
assumptions=not args.no_assumptions,
dry=args.dry,
only=args.only,
skip_migrations=args.skip_migrations,
skip_tests=args.skip_tests,
stream=stream,
)
try:
Expand Down
59 changes: 49 additions & 10 deletions infer_types/_extractors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import ast
import builtins
from collections import deque
from types import MappingProxyType
from typing import Callable, Iterator

import astroid
Expand All @@ -11,20 +13,38 @@


Extractor = Callable[[astroid.FunctionDef], Type]
extractors: list[Extractor] = []
extractors: list[tuple[str, Extractor]] = []


def register(extractor: Extractor) -> Extractor:
extractors.append(extractor)
return extractor
KNOWN_NAMES = MappingProxyType({
'dumps': 'str',
'exists': 'bool',
'contains': 'bool',
'count': 'int',
'size': 'int',
})
REMOVE_PREFIXES = ('as_', 'to_', 'get_')
BOOL_PREFIXES = ('is_', 'has_', 'should_', 'can_', 'will_', 'supports_')


def get_return_type(func_node: astroid.FunctionDef) -> Type | None:
def register(name: str) -> Callable[[Extractor], Extractor]:
def callback(extractor: Extractor) -> Extractor:
extractors.append((name, extractor))
return extractor
return callback


def get_return_type(
func_node: astroid.FunctionDef,
names: frozenset[str],
) -> Type | None:
"""
Recursively walk the given body, find all return stmts,
and infer their type. The result is a union of these types.
"""
for extractor in extractors:
for name, extractor in extractors:
if names and name not in names:
continue
ret_type = extractor(func_node)
if not ret_type.unknown:
return ret_type
Expand All @@ -44,7 +64,7 @@ def walk(func_node: astroid.FunctionDef) -> Iterator[astroid.NodeNG]:
yield node


@register
@register(name='astypes')
def _extract_astypes(func_node: astroid.FunctionDef) -> Type:
result = Type.new('')
for node in walk(func_node):
Expand All @@ -64,7 +84,7 @@ def _extract_astypes(func_node: astroid.FunctionDef) -> Type:
return result


@register
@register(name='inherit')
def _extract_inherit_method(func_node: astroid.FunctionDef) -> Type:
for node in func_node.node_ancestors():
if isinstance(node, astroid.ClassDef):
Expand Down Expand Up @@ -105,17 +125,36 @@ def _extract_inherit_method(func_node: astroid.FunctionDef) -> Type:
return Type.new('')


@register
@register(name='yield')
def _extract_yield(func_node: astroid.FunctionDef) -> Type:
for node in walk(func_node):
if isinstance(node, (astroid.Yield, astroid.YieldFrom)):
return Type.new('Iterator', module='typing')
return Type.new('')


@register
@register(name='none')
def _extract_no_return(func_node: astroid.FunctionDef) -> Type:
for node in walk(func_node):
if isinstance(node, astroid.Return) and node.value is not None:
return Type.new('')
return Type.new('None')


@register(name='name')
def _extract_from_name(func_node: astroid.FunctionDef) -> Type:
"""Try to guess the return type based on the function name.
"""
name: str = func_node.name
name = name.lstrip('_')
# TODO(@orsinium): use str.removeprefix when migrating to 3.9
for prefix in REMOVE_PREFIXES:
if name.startswith(prefix):
name = name[len(prefix):]
if name.startswith(BOOL_PREFIXES):
return Type.new('bool')
if name.endswith('_at'):
return Type.new('datetime', module='datetime')
if hasattr(builtins, name):
return Type.new(name)
return Type.new(KNOWN_NAMES.get(name, ''))
5 changes: 3 additions & 2 deletions infer_types/_inferno.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, field
from logging import getLogger
from pathlib import Path
from typing import Iterator
Expand All @@ -24,6 +24,7 @@ class Inferno:
methods: bool = True
functions: bool = True
assumptions: bool = True
only: frozenset[str] = field(default_factory=frozenset)

def transform(self, path: Path) -> str:
source = path.read_text()
Expand Down Expand Up @@ -70,7 +71,7 @@ def _get_transforms_for_node(self, node: astroid.NodeNG) -> Iterator[Transformat
def _infer_sig(self, node: astroid.FunctionDef) -> FSig | None:
if node.returns is not None:
return None
return_type = get_return_type(node)
return_type = get_return_type(node, names=self.only)
if return_type is None:
return None
if not self.assumptions and return_type.assumptions:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ keywords = [
]
dependencies = [
"astroid",
"astypes>=0.2.1",
"astypes>=0.2.3",
"typeshed_client",
]

Expand All @@ -43,6 +43,7 @@ lint = [
"types-protobuf",
"unify",
]

[project.urls]
Source = "https://github.com/orsinium-labs/infer-types"

Expand Down
26 changes: 26 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,29 @@ def f2(x):
code = main([str(tmp_path), '--no-assumptions'], stream)
assert code == 0
assert source_file.read_text() == dedent(expected)


def test_only(tmp_path: Path):
given = """
def f1():
return 1
def is_used(x):
return x
"""
expected = """
def f1():
return 1
def is_used(x) -> bool:
return x
"""

# prepare files and dirs
source_file = tmp_path / 'example.py'
source_file.write_text(dedent(given))
# call the CLI
stream = StringIO()
code = main([str(tmp_path), '--only', 'name'], stream)
assert code == 0
assert source_file.read_text() == dedent(expected)
32 changes: 32 additions & 0 deletions tests/test_inferno.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,35 @@ def f() -> {e_type}:
return {g_expr}
"""
assert transform(given) == dedent(expected)


@pytest.mark.parametrize('f_name, e_type', [
('is_user', 'bool'),
('has_access', 'bool'),
('exists', 'bool'),
('contains', 'bool'),
('size', 'int'),
('get_size', 'int'),
('as_dict', 'dict'),
('to_dict', 'dict'),
('dumps', 'str'),
('should_fail', 'bool'),
('can_fail', 'bool'),
('will_fail', 'bool'),
('supports_pickups', 'bool'),
('created_at', 'datetime'),
('updated_at', 'datetime'),
])
def test_infer_type_from_function_name(transform, f_name, e_type):
given = f"""
def {f_name}(x):
return x
"""
expected = dedent(f"""
def {f_name}(x) -> {e_type}:
return x
""")
actual = transform(given)
if e_type == 'datetime':
expected = expected.replace('def', 'from datetime import datetime\ndef')
assert actual == expected

0 comments on commit 77a6666

Please sign in to comment.