From 741fed0182addc6eb57cc183ceb13c0caa69eb75 Mon Sep 17 00:00:00 2001 From: gram Date: Fri, 10 Feb 2023 12:19:26 +0100 Subject: [PATCH 1/3] guess return type from the function name --- infer_types/_extractors.py | 32 ++++++++++++++++++++++++++++++++ pyproject.toml | 3 ++- tests/test_inferno.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 1 deletion(-) diff --git a/infer_types/_extractors.py b/infer_types/_extractors.py index b4d2e9b..3fb0921 100644 --- a/infer_types/_extractors.py +++ b/infer_types/_extractors.py @@ -1,7 +1,9 @@ from __future__ import annotations import ast +import builtins from collections import deque +from types import MappingProxyType from typing import Callable, Iterator import astroid @@ -14,6 +16,17 @@ extractors: list[Extractor] = [] +KNOWN_NAMES = MappingProxyType({ + 'dumps': 'str', + 'exists': 'bool', + 'contains': 'bool', + 'count': 'int', + 'size': 'int', +}) +REMOVE_PREFIXES = ('as_', 'to_', 'get_') +BOOL_PREFIXES = ('is_', 'has_', 'should_', 'can_', 'will_', 'supports_') + + def register(extractor: Extractor) -> Extractor: extractors.append(extractor) return extractor @@ -119,3 +132,22 @@ def _extract_no_return(func_node: astroid.FunctionDef) -> Type: if isinstance(node, astroid.Return) and node.value is not None: return Type.new('') return Type.new('None') + + +@register +def _extract_from_name(func_node: astroid.FunctionDef) -> Type: + """Try to guess the return type based on the function name. + """ + name: str = func_node.name + name = name.lstrip('_') + # TODO(@orsinium): use str.removeprefix when migrating to 3.9 + for prefix in REMOVE_PREFIXES: + if name.startswith(prefix): + name = name[len(prefix):] + if name.startswith(BOOL_PREFIXES): + return Type.new('bool') + if name.endswith('_at'): + return Type.new('datetime', module='datetime') + if hasattr(builtins, name): + return Type.new(name) + return Type.new(KNOWN_NAMES.get(name, '')) diff --git a/pyproject.toml b/pyproject.toml index fcc810b..9b4ceae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ keywords = [ ] dependencies = [ "astroid", - "astypes>=0.2.1", + "astypes>=0.2.3", "typeshed_client", ] @@ -43,6 +43,7 @@ lint = [ "types-protobuf", "unify", ] + [project.urls] Source = "https://github.com/orsinium-labs/infer-types" diff --git a/tests/test_inferno.py b/tests/test_inferno.py index 3c19355..26d8d24 100644 --- a/tests/test_inferno.py +++ b/tests/test_inferno.py @@ -275,3 +275,35 @@ def f() -> {e_type}: return {g_expr} """ assert transform(given) == dedent(expected) + + +@pytest.mark.parametrize('f_name, e_type', [ + ('is_user', 'bool'), + ('has_access', 'bool'), + ('exists', 'bool'), + ('contains', 'bool'), + ('size', 'int'), + ('get_size', 'int'), + ('as_dict', 'dict'), + ('to_dict', 'dict'), + ('dumps', 'str'), + ('should_fail', 'bool'), + ('can_fail', 'bool'), + ('will_fail', 'bool'), + ('supports_pickups', 'bool'), + ('created_at', 'datetime'), + ('updated_at', 'datetime'), +]) +def test_infer_type_from_function_name(transform, f_name, e_type): + given = f""" + def {f_name}(x): + return x + """ + expected = dedent(f""" + def {f_name}(x) -> {e_type}: + return x + """) + actual = transform(given) + if e_type == 'datetime': + expected = expected.replace('def', 'from datetime import datetime\ndef') + assert actual == expected From 31d657aee1f6dd70eb8f783b6618fa3410cbeed6 Mon Sep 17 00:00:00 2001 From: gram Date: Fri, 10 Feb 2023 13:32:04 +0100 Subject: [PATCH 2/3] add --only --- infer_types/_cli.py | 22 +++++++++++++++------- infer_types/_extractors.py | 29 ++++++++++++++++++----------- infer_types/_inferno.py | 5 +++-- tests/test_cli.py | 26 ++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 20 deletions(-) diff --git a/infer_types/_cli.py b/infer_types/_cli.py index 63e31a5..0ac5c71 100644 --- a/infer_types/_cli.py +++ b/infer_types/_cli.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import NoReturn, TextIO +from ._extractors import extractors from ._format import format_code from ._inferno import Inferno @@ -27,6 +28,7 @@ class Config: functions: bool # allow annotating functions assumptions: bool # allow astypes to make assumptions dry: bool # do not write changes in files + only: frozenset[str] # run only the these extractors stream: TextIO # stdout @@ -37,6 +39,7 @@ def add_annotations(root: Path, config: Config) -> None: methods=config.methods, functions=config.functions, assumptions=config.assumptions, + only=config.only, ) for path in root.iterdir(): if path.is_dir(): @@ -60,7 +63,11 @@ def main(argv: list[str], stream: TextIO) -> int: parser = ArgumentParser() parser.add_argument( 'dir', type=Path, default=Path(), - help='path to directory with the source code to analyze' + help='path to the directory with the source code to analyze', + ) + parser.add_argument( + '--only', nargs='*', choices=sorted(name for name, _ in extractors), + help='list of extractors to run (all by default)', ) parser.add_argument( '--format', action='store_true', @@ -104,15 +111,16 @@ def main(argv: list[str], stream: TextIO) -> int: ) args = parser.parse_args(argv) config = Config( - format=args.format, - skip_tests=args.skip_tests, - skip_migrations=args.skip_migrations, + assumptions=not args.no_assumptions, + dry=args.dry, exit_on_failure=args.exit_on_failure or args.pdb, + format=args.format, + functions=not args.no_functions, imports=not args.no_imports, methods=not args.no_methods, - functions=not args.no_functions, - assumptions=not args.no_assumptions, - dry=args.dry, + only=args.only, + skip_migrations=args.skip_migrations, + skip_tests=args.skip_tests, stream=stream, ) try: diff --git a/infer_types/_extractors.py b/infer_types/_extractors.py index 3fb0921..73e6107 100644 --- a/infer_types/_extractors.py +++ b/infer_types/_extractors.py @@ -13,7 +13,7 @@ Extractor = Callable[[astroid.FunctionDef], Type] -extractors: list[Extractor] = [] +extractors: list[tuple[str, Extractor]] = [] KNOWN_NAMES = MappingProxyType({ @@ -27,17 +27,24 @@ BOOL_PREFIXES = ('is_', 'has_', 'should_', 'can_', 'will_', 'supports_') -def register(extractor: Extractor) -> Extractor: - extractors.append(extractor) - return extractor +def register(name: str) -> Callable[[Extractor], Extractor]: + def callback(extractor: Extractor) -> Extractor: + extractors.append((name, extractor)) + return extractor + return callback -def get_return_type(func_node: astroid.FunctionDef) -> Type | None: +def get_return_type( + func_node: astroid.FunctionDef, + names: frozenset[str], +) -> Type | None: """ Recursively walk the given body, find all return stmts, and infer their type. The result is a union of these types. """ - for extractor in extractors: + for name, extractor in extractors: + if names and name not in names: + continue ret_type = extractor(func_node) if not ret_type.unknown: return ret_type @@ -57,7 +64,7 @@ def walk(func_node: astroid.FunctionDef) -> Iterator[astroid.NodeNG]: yield node -@register +@register(name='astypes') def _extract_astypes(func_node: astroid.FunctionDef) -> Type: result = Type.new('') for node in walk(func_node): @@ -77,7 +84,7 @@ def _extract_astypes(func_node: astroid.FunctionDef) -> Type: return result -@register +@register(name='inherit') def _extract_inherit_method(func_node: astroid.FunctionDef) -> Type: for node in func_node.node_ancestors(): if isinstance(node, astroid.ClassDef): @@ -118,7 +125,7 @@ def _extract_inherit_method(func_node: astroid.FunctionDef) -> Type: return Type.new('') -@register +@register(name='yield') def _extract_yield(func_node: astroid.FunctionDef) -> Type: for node in walk(func_node): if isinstance(node, (astroid.Yield, astroid.YieldFrom)): @@ -126,7 +133,7 @@ def _extract_yield(func_node: astroid.FunctionDef) -> Type: return Type.new('') -@register +@register(name='none') def _extract_no_return(func_node: astroid.FunctionDef) -> Type: for node in walk(func_node): if isinstance(node, astroid.Return) and node.value is not None: @@ -134,7 +141,7 @@ def _extract_no_return(func_node: astroid.FunctionDef) -> Type: return Type.new('None') -@register +@register(name='name') def _extract_from_name(func_node: astroid.FunctionDef) -> Type: """Try to guess the return type based on the function name. """ diff --git a/infer_types/_inferno.py b/infer_types/_inferno.py index 9fe8e00..c5765fb 100644 --- a/infer_types/_inferno.py +++ b/infer_types/_inferno.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from logging import getLogger from pathlib import Path from typing import Iterator @@ -24,6 +24,7 @@ class Inferno: methods: bool = True functions: bool = True assumptions: bool = True + only: frozenset[str] = field(default_factory=frozenset) def transform(self, path: Path) -> str: source = path.read_text() @@ -70,7 +71,7 @@ def _get_transforms_for_node(self, node: astroid.NodeNG) -> Iterator[Transformat def _infer_sig(self, node: astroid.FunctionDef) -> FSig | None: if node.returns is not None: return None - return_type = get_return_type(node) + return_type = get_return_type(node, names=self.only) if return_type is None: return None if not self.assumptions and return_type.assumptions: diff --git a/tests/test_cli.py b/tests/test_cli.py index 09bdb68..0320913 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -277,3 +277,29 @@ def f2(x): code = main([str(tmp_path), '--no-assumptions'], stream) assert code == 0 assert source_file.read_text() == dedent(expected) + + +def test_only(tmp_path: Path): + given = """ + def f1(): + return 1 + + def is_used(x): + return x + """ + expected = """ + def f1(): + return 1 + + def is_used(x) -> bool: + return x + """ + + # prepare files and dirs + source_file = tmp_path / 'example.py' + source_file.write_text(dedent(given)) + # call the CLI + stream = StringIO() + code = main([str(tmp_path), '--only', 'name'], stream) + assert code == 0 + assert source_file.read_text() == dedent(expected) From 5e1ec018115b7df74fb0a61bd67771db5389e1d7 Mon Sep 17 00:00:00 2001 From: gram Date: Fri, 10 Feb 2023 13:33:18 +0100 Subject: [PATCH 3/3] bump version --- infer_types/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/infer_types/__init__.py b/infer_types/__init__.py index 3401955..8bd1d6a 100644 --- a/infer_types/__init__.py +++ b/infer_types/__init__.py @@ -4,4 +4,4 @@ __all__ = ['entrypoint', 'main'] -__version__ = '0.3.1' +__version__ = '0.3.2'