diff --git a/mypy/checker.py b/mypy/checker.py index c69b80a55fd9..dbaad1e48cee 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -65,10 +65,12 @@ CallExpr, ClassDef, ComparisonExpr, + ComplexExpr, Context, ContinueStmt, Decorator, DelStmt, + DictExpr, EllipsisExpr, Expression, ExpressionStmt, @@ -102,6 +104,7 @@ RaiseStmt, RefExpr, ReturnStmt, + SetExpr, StarExpr, Statement, StrExpr, @@ -352,6 +355,9 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface): # functions such as open(), etc. plugin: Plugin + # A helper state to produce unique temporary names on demand. + _unique_id: int + def __init__( self, errors: Errors, @@ -415,6 +421,7 @@ def __init__( self, self.msg, self.plugin, per_line_checking_time_ns ) self.pattern_checker = PatternChecker(self, self.msg, self.plugin, options) + self._unique_id = 0 @property def type_context(self) -> list[Type | None]: @@ -5412,19 +5419,7 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None: return def visit_match_stmt(self, s: MatchStmt) -> None: - named_subject: Expression - if isinstance(s.subject, CallExpr): - # Create a dummy subject expression to handle cases where a match statement's subject - # is not a literal value. This lets us correctly narrow types and check exhaustivity - # This is hack! - id = s.subject.callee.fullname if isinstance(s.subject.callee, RefExpr) else "" - name = "dummy-match-" + id - v = Var(name) - named_subject = NameExpr(name) - named_subject.node = v - else: - named_subject = s.subject - + named_subject = self._make_named_statement_for_match(s.subject) with self.binder.frame_context(can_skip=False, fall_through=0): subject_type = get_proper_type(self.expr_checker.accept(s.subject)) @@ -5501,6 +5496,38 @@ def visit_match_stmt(self, s: MatchStmt) -> None: with self.binder.frame_context(can_skip=False, fall_through=2): pass + def _make_named_statement_for_match(self, subject: Expression) -> Expression: + """Construct a fake NameExpr for inference if a match clause is complex.""" + expressions_to_preserve = ( + # Already named - we should infer type of it as given + NameExpr, + AssignmentExpr, + # Collection literals defined inline - we want to infer types of variables + # included there, not exprs as a whole + ListExpr, + DictExpr, + TupleExpr, + SetExpr, + # Primitive literals - their type is known, no need to name them + IntExpr, + StrExpr, + BytesExpr, + FloatExpr, + ComplexExpr, + EllipsisExpr, + ) + if isinstance(subject, expressions_to_preserve): + return subject + else: + # Create a dummy subject expression to handle cases where a match statement's subject + # is not a literal value. This lets us correctly narrow types and check exhaustivity + # This is hack! + name = self.new_unique_dummy_name("match") + v = Var(name) + named_subject = NameExpr(name) + named_subject.node = v + return named_subject + def _get_recursive_sub_patterns_map( self, expr: Expression, typ: Type ) -> dict[Expression, Type]: @@ -7878,6 +7905,12 @@ def warn_deprecated_overload_item( if candidate == target: self.warn_deprecated(item.func, context) + def new_unique_dummy_name(self, namespace: str) -> str: + """Generate a name that is guaranteed to be unique for this TypeChecker instance.""" + name = f"dummy-{namespace}-{self._unique_id}" + self._unique_id += 1 + return name + # leafs def visit_pass_stmt(self, o: PassStmt, /) -> None: diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index ea6cc7ffe56a..fd5128581c3e 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1239,7 +1239,7 @@ def main() -> None: case a: reveal_type(a) # N: Revealed type is "builtins.int" -[case testMatchCapturePatternFromAsyncFunctionReturningUnion-xfail] +[case testMatchCapturePatternFromAsyncFunctionReturningUnion] async def func1(arg: bool) -> str | int: ... async def func2(arg: bool) -> bytes | int: ... @@ -2522,3 +2522,86 @@ def fn2(x: Some | int | str) -> None: case Some(value): # E: Incompatible types in capture pattern (pattern captures type "Union[int, str]", variable has type "Callable[[], str]") pass [builtins fixtures/dict.pyi] + +[case testMatchFunctionCall] +# flags: --warn-unreachable + +def fn() -> int | str: ... + +match fn(): + case str(s): + reveal_type(s) # N: Revealed type is "builtins.str" + case int(i): + reveal_type(i) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +[case testMatchAttribute] +# flags: --warn-unreachable + +class A: + foo: int | str + +match A().foo: + case str(s): + reveal_type(s) # N: Revealed type is "builtins.str" + case int(i): + reveal_type(i) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +[case testMatchOperations] +# flags: --warn-unreachable + +x: int +match -x: + case -1 as s: + reveal_type(s) # N: Revealed type is "Literal[-1]" + case int(s): + reveal_type(s) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +match 1 + 2: + case 3 as s: + reveal_type(s) # N: Revealed type is "Literal[3]" + case int(s): + reveal_type(s) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +match 1 > 2: + case True as s: + reveal_type(s) # N: Revealed type is "Literal[True]" + case False as s: + reveal_type(s) # N: Revealed type is "Literal[False]" + case other: + other # E: Statement is unreachable +[builtins fixtures/ops.pyi] + +[case testMatchDictItem] +# flags: --warn-unreachable + +m: dict[str, int | str] +k: str + +match m[k]: + case str(s): + reveal_type(s) # N: Revealed type is "builtins.str" + case int(i): + reveal_type(i) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +[builtins fixtures/dict.pyi] + +[case testMatchLiteralValuePathological] +# flags: --warn-unreachable + +match 0: + case 0 as i: + reveal_type(i) # N: Revealed type is "Literal[0]?" + case int(i): + i # E: Statement is unreachable + case other: + other # E: Statement is unreachable