5252
5353
5454@overload
55- def expand_type (typ : CallableType , env : Mapping [TypeVarId , Type ]) -> CallableType : ...
55+ def expand_type (
56+ typ : CallableType , env : Mapping [TypeVarId , Type ], * , keep_none_type : bool = ...
57+ ) -> CallableType : ...
5658
5759
5860@overload
59- def expand_type (typ : ProperType , env : Mapping [TypeVarId , Type ]) -> ProperType : ...
61+ def expand_type (
62+ typ : ProperType , env : Mapping [TypeVarId , Type ], * , strict_optional : bool = ...
63+ ) -> ProperType : ...
6064
6165
6266@overload
63- def expand_type (typ : Type , env : Mapping [TypeVarId , Type ]) -> Type : ...
67+ def expand_type (
68+ typ : Type , env : Mapping [TypeVarId , Type ], * , strict_optional : bool = ...
69+ ) -> Type : ...
6470
6571
66- def expand_type (typ : Type , env : Mapping [TypeVarId , Type ]) -> Type :
72+ def expand_type (
73+ typ : Type , env : Mapping [TypeVarId , Type ], * , strict_optional : bool = False
74+ ) -> Type :
6775 """Substitute any type variable references in a type given by a type
6876 environment.
6977 """
70- return typ .accept (ExpandTypeVisitor (env ))
78+ return typ .accept (ExpandTypeVisitor (env , strict_optional ))
7179
7280
7381@overload
@@ -184,8 +192,9 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):
184192
185193 variables : Mapping [TypeVarId , Type ] # TypeVar id -> TypeVar value
186194
187- def __init__ (self , variables : Mapping [TypeVarId , Type ]) -> None :
195+ def __init__ (self , variables : Mapping [TypeVarId , Type ], strict_optional : bool = False ) -> None :
188196 self .variables = variables
197+ self .strict_optional = strict_optional
189198 self .recursive_guard : set [Type | tuple [int , Type ]] = set ()
190199
191200 def visit_unbound_type (self , t : UnboundType ) -> Type :
@@ -460,7 +469,7 @@ def visit_union_type(self, t: UnionType) -> Type:
460469 # might be subtypes of others, however calling make_simplified_union()
461470 # can cause recursion, so we just remove strict duplicates.
462471 simplified = UnionType .make_union (
463- remove_trivial (flatten_nested_unions (expanded )), t .line , t .column
472+ remove_trivial (flatten_nested_unions (expanded ), self . strict_optional ), t .line , t .column
464473 )
465474 # This call to get_proper_type() is unfortunate but is required to preserve
466475 # the invariant that ProperType will stay ProperType after applying expand_type(),
@@ -508,7 +517,7 @@ def expand_self_type(var: Var, typ: Type, replacement: Type) -> Type:
508517 return typ
509518
510519
511- def remove_trivial (types : Iterable [Type ]) -> list [Type ]:
520+ def remove_trivial (types : Iterable [Type ], strict_optional : bool = False ) -> list [Type ]:
512521 """Make trivial simplifications on a list of types without calling is_subtype().
513522
514523 This makes following simplifications:
@@ -523,7 +532,7 @@ def remove_trivial(types: Iterable[Type]) -> list[Type]:
523532 p_t = get_proper_type (t )
524533 if isinstance (p_t , UninhabitedType ):
525534 continue
526- if isinstance (p_t , NoneType ) and not state .strict_optional :
535+ if isinstance (p_t , NoneType ) and not state .strict_optional and not strict_optional :
527536 removed_none = True
528537 continue
529538 if isinstance (p_t , Instance ) and p_t .type .fullname == "builtins.object" :
0 commit comments