From d47fba1e4aeeb18085900dfbbcd187e90d536913 Mon Sep 17 00:00:00 2001 From: Dhruv Manilawala Date: Fri, 20 Dec 2024 16:31:15 +0530 Subject: [PATCH] [red-knot] Add support for unpacking union types (#15052) ## Summary Refer: https://github.com/astral-sh/ruff/issues/13773#issuecomment-2548020368 This PR adds support for unpacking union types. Unpacking a union type requires us to first distribute the types for all the targets that are involved in an unpacking. For example, if there are two targets and a union type that needs to be unpacked, each target will get a type from each element in the union type. For example, if the type is `tuple[int, int] | tuple[int, str]` and the target has two elements `(a, b)`, then * The type of `a` will be a union of `int` and `int` which are at index 0 in the first and second tuple respectively which resolves to an `int`. * Similarly, the type of `b` will be a union of `int` and `str` which are at index 1 in the first and second tuple respectively which will be `int | str`. ### Refactors There are couple of refactors that are added in this PR: * Add a `debug_assertion` to validate that the unpack target is a list or a tuple * Add a separate method to handle starred expression ## Test Plan Update `unpacking.md` with additional test cases that uses union types. This is done using parameter type hints style. --- .../resources/mdtest/unpacking.md | 166 ++++++++++++++ crates/red_knot_python_semantic/src/types.rs | 7 + .../src/types/infer.rs | 4 +- .../src/types/unpacker.rs | 214 +++++++++++------- 4 files changed, 310 insertions(+), 81 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/unpacking.md b/crates/red_knot_python_semantic/resources/mdtest/unpacking.md index 1da17a4f27870..1ee1c5edf0094 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/unpacking.md +++ b/crates/red_knot_python_semantic/resources/mdtest/unpacking.md @@ -306,3 +306,169 @@ reveal_type(b) # revealed: Unknown reveal_type(a) # revealed: LiteralString reveal_type(b) # revealed: LiteralString ``` + +## Union + +### Same types + +Union of two tuples of equal length and each element is of the same type. + +```py +def _(arg: tuple[int, int] | tuple[int, int]): + (a, b) = arg + reveal_type(a) # revealed: int + reveal_type(b) # revealed: int +``` + +### Mixed types (1) + +Union of two tuples of equal length and one element differs in its type. + +```py +def _(arg: tuple[int, int] | tuple[int, str]): + a, b = arg + reveal_type(a) # revealed: int + reveal_type(b) # revealed: int | str +``` + +### Mixed types (2) + +Union of two tuples of equal length and both the element types are different. + +```py +def _(arg: tuple[int, str] | tuple[str, int]): + a, b = arg + reveal_type(a) # revealed: int | str + reveal_type(b) # revealed: str | int +``` + +### Mixed types (3) + +Union of three tuples of equal length and various combination of element types: + +1. All same types +1. One different type +1. All different types + +```py +def _(arg: tuple[int, int, int] | tuple[int, str, bytes] | tuple[int, int, str]): + a, b, c = arg + reveal_type(a) # revealed: int + reveal_type(b) # revealed: int | str + reveal_type(c) # revealed: int | bytes | str +``` + +### Nested + +```py +def _(arg: tuple[int, tuple[str, bytes]] | tuple[tuple[int, bytes], Literal["ab"]]): + a, (b, c) = arg + reveal_type(a) # revealed: int | tuple[int, bytes] + reveal_type(b) # revealed: str + reveal_type(c) # revealed: bytes | LiteralString +``` + +### Starred expression + +```py +def _(arg: tuple[int, bytes, int] | tuple[int, int, str, int, bytes]): + a, *b, c = arg + reveal_type(a) # revealed: int + # TODO: Should be `list[bytes | int | str]` + reveal_type(b) # revealed: @Todo(starred unpacking) + reveal_type(c) # revealed: int | bytes +``` + +### Size mismatch (1) + +```py +def _(arg: tuple[int, bytes, int] | tuple[int, int, str, int, bytes]): + # TODO: Add diagnostic (too many values to unpack) + a, b = arg + reveal_type(a) # revealed: int + reveal_type(b) # revealed: bytes | int +``` + +### Size mismatch (2) + +```py +def _(arg: tuple[int, bytes] | tuple[int, str]): + # TODO: Add diagnostic (there aren't enough values to unpack) + a, b, c = arg + reveal_type(a) # revealed: int + reveal_type(b) # revealed: bytes | str + reveal_type(c) # revealed: Unknown +``` + +### Same literal types + +```py +def _(flag: bool): + if flag: + value = (1, 2) + else: + value = (3, 4) + + a, b = value + reveal_type(a) # revealed: Literal[1, 3] + reveal_type(b) # revealed: Literal[2, 4] +``` + +### Mixed literal types + +```py +def _(flag: bool): + if flag: + value = (1, 2) + else: + value = ("a", "b") + + a, b = value + reveal_type(a) # revealed: Literal[1] | Literal["a"] + reveal_type(b) # revealed: Literal[2] | Literal["b"] +``` + +### Typing literal + +```py +from typing import Literal + +def _(arg: tuple[int, int] | Literal["ab"]): + a, b = arg + reveal_type(a) # revealed: int | LiteralString + reveal_type(b) # revealed: int | LiteralString +``` + +### Custom iterator (1) + +```py +class Iterator: + def __next__(self) -> tuple[int, int] | tuple[int, str]: + return (1, 2) + +class Iterable: + def __iter__(self) -> Iterator: + return Iterator() + +((a, b), c) = Iterable() +reveal_type(a) # revealed: int +reveal_type(b) # revealed: int | str +reveal_type(c) # revealed: tuple[int, int] | tuple[int, str] +``` + +### Custom iterator (2) + +```py +class Iterator: + def __next__(self) -> bytes: + return b"" + +class Iterable: + def __iter__(self) -> Iterator: + return Iterator() + +def _(arg: tuple[int, str] | Iterable): + a, b = arg + reveal_type(a) # revealed: int | bytes + reveal_type(b) # revealed: str | bytes +``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index d655e18d8f848..87c5df9482999 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -580,6 +580,13 @@ impl<'db> Type<'db> { .expect("Expected a Type::KnownInstance variant") } + pub const fn into_tuple(self) -> Option> { + match self { + Type::Tuple(tuple_type) => Some(tuple_type), + _ => None, + } + } + pub const fn is_boolean_literal(&self) -> bool { matches!(self, Type::BooleanLiteral(..)) } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 084451ee50d1d..cb7376de1c422 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -207,8 +207,8 @@ fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> UnpackResult let result = infer_expression_types(db, value); let value_ty = result.expression_ty(value.node_ref(db).scoped_expression_id(db, scope)); - let mut unpacker = Unpacker::new(db, file); - unpacker.unpack(unpack.target(db), value_ty, scope); + let mut unpacker = Unpacker::new(db, scope); + unpacker.unpack(unpack.target(db), value_ty); unpacker.finish() } diff --git a/crates/red_knot_python_semantic/src/types/unpacker.rs b/crates/red_knot_python_semantic/src/types/unpacker.rs index aa1820357aa7a..c22fdda4378c1 100644 --- a/crates/red_knot_python_semantic/src/types/unpacker.rs +++ b/crates/red_knot_python_semantic/src/types/unpacker.rs @@ -1,27 +1,30 @@ use std::borrow::Cow; -use ruff_db::files::File; -use ruff_python_ast::{self as ast, AnyNodeRef}; use rustc_hash::FxHashMap; +use ruff_python_ast::{self as ast, AnyNodeRef}; + use crate::semantic_index::ast_ids::{HasScopedExpressionId, ScopedExpressionId}; use crate::semantic_index::symbol::ScopeId; use crate::types::{todo_type, Type, TypeCheckDiagnostics}; use crate::Db; use super::context::{InferContext, WithDiagnostics}; +use super::{TupleType, UnionType}; /// Unpacks the value expression type to their respective targets. pub(crate) struct Unpacker<'db> { context: InferContext<'db>, + scope: ScopeId<'db>, targets: FxHashMap>, } impl<'db> Unpacker<'db> { - pub(crate) fn new(db: &'db dyn Db, file: File) -> Self { + pub(crate) fn new(db: &'db dyn Db, scope: ScopeId<'db>) -> Self { Self { - context: InferContext::new(db, file), + context: InferContext::new(db, scope.file(db)), targets: FxHashMap::default(), + scope, } } @@ -29,98 +32,151 @@ impl<'db> Unpacker<'db> { self.context.db() } - pub(crate) fn unpack(&mut self, target: &ast::Expr, value_ty: Type<'db>, scope: ScopeId<'db>) { + /// Unpack the value type to the target expression. + pub(crate) fn unpack(&mut self, target: &ast::Expr, value_ty: Type<'db>) { + debug_assert!( + matches!(target, ast::Expr::List(_) | ast::Expr::Tuple(_)), + "Unpacking target must be a list or tuple expression" + ); + + self.unpack_inner(target, value_ty); + } + + fn unpack_inner(&mut self, target: &ast::Expr, value_ty: Type<'db>) { match target { ast::Expr::Name(target_name) => { - self.targets - .insert(target_name.scoped_expression_id(self.db(), scope), value_ty); + self.targets.insert( + target_name.scoped_expression_id(self.db(), self.scope), + value_ty, + ); } ast::Expr::Starred(ast::ExprStarred { value, .. }) => { - self.unpack(value, value_ty, scope); + self.unpack_inner(value, value_ty); } ast::Expr::List(ast::ExprList { elts, .. }) - | ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => match value_ty { - Type::Tuple(tuple_ty) => { - let starred_index = elts.iter().position(ast::Expr::is_starred_expr); - - let element_types = if let Some(starred_index) = starred_index { - if tuple_ty.len(self.db()) >= elts.len() - 1 { - let mut element_types = Vec::with_capacity(elts.len()); - element_types.extend_from_slice( - // SAFETY: Safe because of the length check above. - &tuple_ty.elements(self.db())[..starred_index], - ); - - // E.g., in `(a, *b, c, d) = ...`, the index of starred element `b` - // is 1 and the remaining elements after that are 2. - let remaining = elts.len() - (starred_index + 1); - // This index represents the type of the last element that belongs - // to the starred expression, in an exclusive manner. - let starred_end_index = tuple_ty.len(self.db()) - remaining; - // SAFETY: Safe because of the length check above. - let _starred_element_types = - &tuple_ty.elements(self.db())[starred_index..starred_end_index]; - // TODO: Combine the types into a list type. If the - // starred_element_types is empty, then it should be `List[Any]`. - // combine_types(starred_element_types); - element_types.push(todo_type!("starred unpacking")); - - element_types.extend_from_slice( - // SAFETY: Safe because of the length check above. - &tuple_ty.elements(self.db())[starred_end_index..], - ); - Cow::Owned(element_types) - } else { - let mut element_types = tuple_ty.elements(self.db()).to_vec(); - // Subtract 1 to insert the starred expression type at the correct - // index. - element_types.resize(elts.len() - 1, Type::Unknown); - // TODO: This should be `list[Unknown]` - element_types.insert(starred_index, todo_type!("starred unpacking")); - Cow::Owned(element_types) + | ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => { + // Initialize the vector of target types, one for each target. + // + // This is mainly useful for the union type where the target type at index `n` is + // going to be a union of types from every union type element at index `n`. + // + // For example, if the type is `tuple[int, int] | tuple[int, str]` and the target + // has two elements `(a, b)`, then + // * The type of `a` will be a union of `int` and `int` which are at index 0 in the + // first and second tuple respectively which resolves to an `int`. + // * Similarly, the type of `b` will be a union of `int` and `str` which are at + // index 1 in the first and second tuple respectively which will be `int | str`. + let mut target_types = vec![vec![]; elts.len()]; + + let unpack_types = match value_ty { + Type::Union(union_ty) => union_ty.elements(self.db()), + _ => std::slice::from_ref(&value_ty), + }; + + for ty in unpack_types.iter().copied() { + // Deconstruct certain types to delegate the inference back to the tuple type + // for correct handling of starred expressions. + let ty = match ty { + Type::StringLiteral(string_literal_ty) => { + // We could go further and deconstruct to an array of `StringLiteral` + // with each individual character, instead of just an array of + // `LiteralString`, but there would be a cost and it's not clear that + // it's worth it. + Type::tuple( + self.db(), + std::iter::repeat(Type::LiteralString) + .take(string_literal_ty.python_len(self.db())), + ) } - } else { - Cow::Borrowed(tuple_ty.elements(self.db()).as_ref()) + _ => ty, }; - for (index, element) in elts.iter().enumerate() { - self.unpack( - element, - element_types.get(index).copied().unwrap_or(Type::Unknown), - scope, - ); + if let Some(tuple_ty) = ty.into_tuple() { + let tuple_ty_elements = self.tuple_ty_elements(elts, tuple_ty); + + // TODO: Add diagnostic for length mismatch + + for (index, ty) in tuple_ty_elements.iter().enumerate() { + if let Some(element_types) = target_types.get_mut(index) { + element_types.push(*ty); + } + } + } else { + let ty = if ty.is_literal_string() { + Type::LiteralString + } else { + ty.iterate(self.db()) + .unwrap_with_diagnostic(&self.context, AnyNodeRef::from(target)) + }; + for target_type in &mut target_types { + target_type.push(ty); + } } } - Type::StringLiteral(string_literal_ty) => { - // Deconstruct the string literal to delegate the inference back to the - // tuple type for correct handling of starred expressions. We could go - // further and deconstruct to an array of `StringLiteral` with each - // individual character, instead of just an array of `LiteralString`, but - // there would be a cost and it's not clear that it's worth it. - let value_ty = Type::tuple( - self.db(), - std::iter::repeat(Type::LiteralString) - .take(string_literal_ty.python_len(self.db())), - ); - self.unpack(target, value_ty, scope); - } - _ => { - let value_ty = if value_ty.is_literal_string() { - Type::LiteralString - } else { - value_ty - .iterate(self.db()) - .unwrap_with_diagnostic(&self.context, AnyNodeRef::from(target)) + + for (index, element) in elts.iter().enumerate() { + // SAFETY: `target_types` is initialized with the same length as `elts`. + let element_ty = match target_types[index].as_slice() { + [] => Type::Unknown, + types => UnionType::from_elements(self.db(), types), }; - for element in elts { - self.unpack(element, value_ty, scope); - } + self.unpack_inner(element, element_ty); } - }, + } _ => {} } } + /// Returns the [`Type`] elements inside the given [`TupleType`] taking into account that there + /// can be a starred expression in the `elements`. + fn tuple_ty_elements( + &mut self, + targets: &[ast::Expr], + tuple_ty: TupleType<'db>, + ) -> Cow<'_, [Type<'db>]> { + // If there is a starred expression, it will consume all of the entries at that location. + let Some(starred_index) = targets.iter().position(ast::Expr::is_starred_expr) else { + // Otherwise, the types will be unpacked 1-1 to the elements. + return Cow::Borrowed(tuple_ty.elements(self.db()).as_ref()); + }; + + if tuple_ty.len(self.db()) >= targets.len() - 1 { + let mut element_types = Vec::with_capacity(targets.len()); + element_types.extend_from_slice( + // SAFETY: Safe because of the length check above. + &tuple_ty.elements(self.db())[..starred_index], + ); + + // E.g., in `(a, *b, c, d) = ...`, the index of starred element `b` + // is 1 and the remaining elements after that are 2. + let remaining = targets.len() - (starred_index + 1); + // This index represents the type of the last element that belongs + // to the starred expression, in an exclusive manner. + let starred_end_index = tuple_ty.len(self.db()) - remaining; + // SAFETY: Safe because of the length check above. + let _starred_element_types = + &tuple_ty.elements(self.db())[starred_index..starred_end_index]; + // TODO: Combine the types into a list type. If the + // starred_element_types is empty, then it should be `List[Any]`. + // combine_types(starred_element_types); + element_types.push(todo_type!("starred unpacking")); + + element_types.extend_from_slice( + // SAFETY: Safe because of the length check above. + &tuple_ty.elements(self.db())[starred_end_index..], + ); + Cow::Owned(element_types) + } else { + let mut element_types = tuple_ty.elements(self.db()).to_vec(); + // Subtract 1 to insert the starred expression type at the correct + // index. + element_types.resize(targets.len() - 1, Type::Unknown); + // TODO: This should be `list[Unknown]` + element_types.insert(starred_index, todo_type!("starred unpacking")); + Cow::Owned(element_types) + } + } + pub(crate) fn finish(mut self) -> UnpackResult<'db> { self.targets.shrink_to_fit(); UnpackResult {