diff --git a/pytype/abstract/_pytd_function.py b/pytype/abstract/_pytd_function.py index 90200c396..10a4bfac2 100644 --- a/pytype/abstract/_pytd_function.py +++ b/pytype/abstract/_pytd_function.py @@ -260,40 +260,57 @@ def compatible_with(new, existing, view): name = mutation.name values = mutation.value if obj.from_annotation: - params = obj.get_instance_type_parameter(name) - ps = {v for v in params.data if should_check(v)} - if ps: - filtered_values = self.ctx.program.NewVariable() - # check if the container type is being broadened. - new = [] - short_name = name.rsplit(".", 1)[-1] - for b in values.bindings: - if not should_check(b.data) or b.data in ps: - filtered_values.PasteBinding(b) - continue - new_view = datatypes.AccessTrackingDict.merge( - combined_view, view, {values: b}) - if not compatible_with(values, ps, new_view): - combination = [b] - bad_param = b.data.get_instance_type_parameter(short_name) - if bad_param in new_view: - combination.append(new_view[bad_param]) - if not node.HasCombination(combination): - # Since HasCombination is expensive, we don't use it to - # pre-filter bindings, but once we think we have an error, we - # should double-check that the binding is actually visible. We - # also drop non-visible bindings from filtered_values. - continue - filtered_values.PasteBinding(b) - new.append(b.data) - # By updating filtered_mutations only when ps is non-empty, we - # filter out mutations to parameters with type Any. - filtered_mutations.append( - function.Mutation(obj, name, filtered_values)) - if new: - errors[obj][short_name] = (params, values, obj.from_annotation) + # We should check for parameter mismatches only if the class is + # generic. Consider: + # class A(tuple[int, int]): ... + # class B(tuple): ... + # Although pytype computes mutations for tuple.__new__ for both + # classes, the second implicitly inherits from tuple[Any, ...], so + # there are no restrictions on the container contents. + check_params = False + for cls in obj.cls.mro: + if isinstance(cls, _classes.ParameterizedClass): + check_params = True + break + elif cls.template: + break else: + check_params = False + if not check_params: filtered_mutations.append(function.Mutation(obj, name, values)) + continue + params = obj.get_instance_type_parameter(name) + ps = {v for v in params.data if should_check(v)} + if ps: + filtered_values = self.ctx.program.NewVariable() + # check if the container type is being broadened. + new = [] + short_name = name.rsplit(".", 1)[-1] + for b in values.bindings: + if not should_check(b.data) or b.data in ps: + filtered_values.PasteBinding(b) + continue + new_view = datatypes.AccessTrackingDict.merge( + combined_view, view, {values: b}) + if not compatible_with(values, ps, new_view): + combination = [b] + bad_param = b.data.get_instance_type_parameter(short_name) + if bad_param in new_view: + combination.append(new_view[bad_param]) + if not node.HasCombination(combination): + # Since HasCombination is expensive, we don't use it to + # pre-filter bindings, but once we think we have an error, we + # should double-check that the binding is actually visible. We + # also drop non-visible bindings from filtered_values. + continue + filtered_values.PasteBinding(b) + new.append(b.data) + # By updating filtered_mutations only when ps is non-empty, we + # filter out mutations to parameters with type Any. + filtered_mutations.append( + function.Mutation(obj, name, filtered_values)) + if new: + errors[obj][short_name] = (params, values, obj.from_annotation) all_mutations = filtered_mutations diff --git a/pytype/config.py b/pytype/config.py index ee2dfbf4b..30b3ee31f 100644 --- a/pytype/config.py +++ b/pytype/config.py @@ -242,7 +242,7 @@ def add_options(o, arglist): _flag("--overriding-renamed-parameter-count-checks", False, "Enable parameter count checks for overriding methods with " "renamed arguments."), - _flag("--use-enum-overlay", False, + _flag("--use-enum-overlay", True, "Use the enum overlay for more precise enum checking."), _flag("--strict-none-binding", False, "Variables initialized as None retain their None binding."), diff --git a/pytype/overlays/fiddle_overlay.py b/pytype/overlays/fiddle_overlay.py index 065d232a7..e38415266 100644 --- a/pytype/overlays/fiddle_overlay.py +++ b/pytype/overlays/fiddle_overlay.py @@ -24,6 +24,14 @@ _INSTANCE_CACHE: Dict[Tuple[Node, abstract.Class, str], abstract.Instance] = {} +_CLASS_ALIASES = { + "Config": "Config", + "PaxConfig": "Config", + "Partial": "Partial", + "PaxPartial": "Partial" +} + + class FiddleOverlay(overlay.Overlay): """A custom overlay for the 'fiddle' module.""" @@ -61,7 +69,7 @@ def __init__(self, name, ctx, module): mixin.HasSlots.init_mixin(self) self.set_native_slot("__getitem__", self.getitem_slot) # For consistency with the rest of the overlay - self.fiddle_type_name = name + self.fiddle_type_name = _CLASS_ALIASES[name] self.module = module def __repr__(self): @@ -150,7 +158,9 @@ def getitem_slot(self, node, index_var) -> Tuple[Node, abstract.Instance]: """Specialize the generic class with the value of index_var.""" underlying = index_var.data[0] - ret = BuildableType(self.name, underlying, self.ctx, module=self.module) + ret = BuildableType( + self.fiddle_type_name, underlying, self.ctx, module=self.module + ) return node, ret.to_variable(node) def get_own_new(self, node, value) -> Tuple[Node, Variable]: @@ -177,12 +187,16 @@ def __init__( super().__init__(base_cls, formal_type_parameters, ctx, template) # pytype: disable=wrong-arg-types self.fiddle_type_name = fiddle_type_name self.underlying = underlying + self.module = module def replace(self, inner_types): inner_types = dict(inner_types) new_underlying = inner_types[abstract_utils.T] typ = self.__class__ - return typ(self.fiddle_type_name, new_underlying, self.ctx, self.template) + return typ( + self.fiddle_type_name, new_underlying, self.ctx, self.template, + self.module + ) def instantiate(self, node, container=None): _, ret = make_instance( @@ -244,6 +258,7 @@ def make_instance( ) -> Tuple[Node, abstract.BaseValue]: """Generate a Buildable instance from an underlying template class.""" + subclass_name = _CLASS_ALIASES[subclass_name] if subclass_name not in ("Config", "Partial"): raise ValueError(f"Unexpected instance class: {subclass_name}") diff --git a/pytype/pyi/definitions.py b/pytype/pyi/definitions.py index f605fc0fd..7e01cd238 100644 --- a/pytype/pyi/definitions.py +++ b/pytype/pyi/definitions.py @@ -173,6 +173,10 @@ def _convert_annotated(x): """Convert everything to a string to store it in pytd.Annotated.""" if isinstance(x, types.Pyval): return x.repr_str() + # TODO(rechen): ast.unparse is new in Python 3.9, so we can drop the hasattr + # check once pytype stops supporting 3.8. + elif isinstance(x, astlib.AST) and hasattr(astlib, "unparse"): + return astlib.unparse(x) elif isinstance(x, dict): return metadata.to_string(x) elif isinstance(x, tuple): diff --git a/pytype/pyi/parser_test.py b/pytype/pyi/parser_test.py index 7a3e4a3f2..3cb5674f4 100644 --- a/pytype/pyi/parser_test.py +++ b/pytype/pyi/parser_test.py @@ -6,6 +6,7 @@ from pytype.pyi import parser_test_base from pytype.pytd import pytd from pytype.tests import test_base +from pytype.tests import test_utils import unittest @@ -2846,6 +2847,20 @@ class Foo: y: Annotated[int, {'tag': 'call', 'fn': 'unit', 'posargs': ('s',), 'kwargs': {'exp': 9}}] """) + @test_utils.skipBeforePy((3, 9), "requires ast.unparse, new in 3.9") + def test_name(self): + self.check(""" + from typing_extensions import Annotated + + class Foo: + x: Annotated[int, Signal] + """, """ + from typing_extensions import Annotated + + class Foo: + x: Annotated[int, Signal] + """) + class ErrorTest(test_base.UnitTest): """Test parser errors.""" diff --git a/pytype/pytd/printer.py b/pytype/pytd/printer.py index 492c30990..204b77da5 100644 --- a/pytype/pytd/printer.py +++ b/pytype/pytd/printer.py @@ -723,7 +723,7 @@ def VisitGenericType(self, node): assert isinstance(param, (pytd.NothingType, pytd.TypeParameter)), param parameters = ("...",) + parameters[1:] return (self.MaybeCapitalize(node.base_type) + - "[" + ", ".join(parameters) + "]") + "[" + ", ".join(str(p) for p in parameters) + "]") def VisitCallableType(self, node): typ = self.MaybeCapitalize(node.base_type) diff --git a/pytype/tests/test_test_code.py b/pytype/tests/test_test_code.py index 894816712..6880f1d14 100644 --- a/pytype/tests/test_test_code.py +++ b/pytype/tests/test_test_code.py @@ -142,6 +142,22 @@ def test_bar(self): other_mock.return_value.__enter__ = lambda x: x """) + def test_decorated_setup(self): + self.Check(""" + from typing import Any + import unittest + from unittest import mock + + random_module: Any + + class FooTest(unittest.TestCase): + @mock.patch.object(random_module, 'attr') + def setUp(self): + self.x = 42 + def test_something(self): + assert_type(self.x, int) + """) + if __name__ == "__main__": test_base.main() diff --git a/pytype/tests/test_tuple2.py b/pytype/tests/test_tuple2.py index 37b841733..e0e39ee6b 100644 --- a/pytype/tests/test_tuple2.py +++ b/pytype/tests/test_tuple2.py @@ -253,6 +253,33 @@ def f(x: tuple[int, int]): f((0, 0)) # ok """) + def test_imported_tuple_subclass_with_new(self): + with self.DepTree([("foo.pyi", """ + from typing import TypeVar + _T = TypeVar('_T', bound=C) + class C(tuple): + def __new__( + cls: type[_T], x: str | list[tuple[int, tuple[int, int]]] + ) -> _T: ... + """)]): + ty = self.Infer(""" + import foo + class A: + def __init__(self, c: foo.C): + self.c = foo.C('+'.join([f'{x}{y}' for x, y in c])) + class B: + def __init__(self, c: foo.C = foo.C([(0, (1, 2))])): + pass + """) + self.assertTypesMatchPytd(ty, """ + import foo + class A: + c: foo.C + def __init__(self, c: foo.C) -> None: ... + class B: + def __init__(self, c: foo.C = ...) -> None: ... + """) + if __name__ == "__main__": test_base.main() diff --git a/pytype/tools/analyze_project/pytype_runner_test.py b/pytype/tools/analyze_project/pytype_runner_test.py index 3c7f5e9f5..e21b21dbe 100644 --- a/pytype/tools/analyze_project/pytype_runner_test.py +++ b/pytype/tools/analyze_project/pytype_runner_test.py @@ -232,7 +232,7 @@ def setUp(self): def assertFlags(self, flags, expected_flags): # Add temporary flags that are set to true by default here, so that they are # filtered out of tests. - temporary_flags = set() + temporary_flags = {'--use-enum-overlay'} self.assertEqual(flags - temporary_flags, expected_flags) # --disable tests a flag with a string value. diff --git a/pytype/tracer_vm.py b/pytype/tracer_vm.py index 304a4fd57..f8791a641 100644 --- a/pytype/tracer_vm.py +++ b/pytype/tracer_vm.py @@ -401,7 +401,8 @@ def bind(cur_node, m): cls = valself.data.cls bound_method = bind(node, method) if obj == cls else method if (not isinstance(cls, abstract.InterpreterClass) or - any(isinstance(m, abstract.FUNCTION_TYPES) for m in bound_method.data)): + any(isinstance(m, abstract.INTERPRETER_FUNCTION_TYPES) + for m in bound_method.data)): return node, bound_method # If the method is not something that pytype recognizes as a function - # which can happen if the method is decorated, for example - then we look up