Skip to content

Commit

Permalink
fix logic to determine whether to use generic bounds for narrowing wh…
Browse files Browse the repository at this point in the history
…en the value being narrowed is a union
  • Loading branch information
DetachHead committed Nov 22, 2024
1 parent e0ef69d commit 74f3b38
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 10 deletions.
8 changes: 5 additions & 3 deletions packages/pyright-internal/src/analyzer/typeGuards.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1137,18 +1137,20 @@ export function getIsInstanceClassTypes(
): (ClassType | TypeVarType | FunctionType)[] | undefined {
let foundNonClassType = false;
const classTypeList: (ClassType | TypeVarType | FunctionType)[] = [];
const useVarianceForSpecialization = shouldUseVarianceForSpecialization(typeToNarrow, improvedGenericNarrowing);
// Create a helper function that returns a list of class types or
// undefined if any of the types are not valid.
const addClassTypesToList = (types: Type[]) => {
types.forEach((type) => {
const subtypes: Type[] = [];
if (isClass(type)) {
evaluator.inferVarianceForClass(type);
const useVariance = shouldUseVarianceForSpecialization(typeToNarrow, improvedGenericNarrowing);
if (useVariance) {
evaluator.inferVarianceForClass(type);
}
type = specializeWithUnknownTypeArgs(
type,
evaluator.getTupleClassType(),
useVarianceForSpecialization ? evaluator.getObjectType() : undefined
useVariance ? evaluator.getObjectType() : undefined
);

doForEachSubtype(type, (subtype) => {
Expand Down
18 changes: 13 additions & 5 deletions packages/pyright-internal/src/analyzer/typeUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -778,9 +778,7 @@ export function someSubtypes(type: Type, callback: (type: Type) => boolean): boo

export function allSubtypes(type: Type, callback: (type: Type) => boolean): boolean {
if (isUnion(type)) {
return type.priv.subtypes.every((subtype) => {
callback(subtype);
});
return type.priv.subtypes.every((subtype) => callback(subtype));
} else {
return callback(type);
}
Expand Down Expand Up @@ -1085,8 +1083,18 @@ export function getTypeVarScopeIds(type: Type): TypeVarScopeId[] {
* If the type we're narrowing already has type parameters,
* there's no need to use variance for specialization.
*/
export const shouldUseVarianceForSpecialization = (type: Type, improvedGenericNarrowing: boolean) =>
improvedGenericNarrowing && (type.category !== TypeCategory.Class || type.shared.typeParams.length === 0);
export const shouldUseVarianceForSpecialization = (typeToNarrow: Type, improvedGenericNarrowing: boolean) => {
if (!improvedGenericNarrowing) {
return false;
}
return allSubtypes(
typeToNarrow,
(subtype) =>
subtype.category !== TypeCategory.Class ||
// !ClassType.isSameGenericClass(subtype, narrowToType) ||
subtype.shared.typeParams.length === 0
);
};

/**
* Specializes the class with "Unknown" type args (or the equivalent for ParamSpecs or TypeVarTuples), or its
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Never, assert_type, Iterable, Iterator, MutableMapping, Reversible
from typing import Any, Never, assert_type, runtime_checkable, Protocol, Iterable, Iterator, MutableMapping, Reversible


class Covariant[T]:
Expand Down Expand Up @@ -106,3 +106,13 @@ class Constraints[T: (int, str), U: (int, str), V: int]:
def _(value: object):
if isinstance(value, Constraints):
assert_type(value, Constraints[int, int, int] | Constraints[int, str, int] | Constraints[str, int, int] | Constraints[str, str, int])

@runtime_checkable
class Foo[T: (int, str)](Protocol):
def asdf(self): ...

def _(
value: str | Foo[str],
):
if isinstance(value, Foo):
assert_type(value, Foo[str])
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, assert_type, Iterable, Iterator, MutableMapping, Reversible
from typing import Any, assert_type, runtime_checkable, Protocol, Iterable, Iterator, MutableMapping, Reversible


class Covariant[T]:
Expand Down Expand Up @@ -106,3 +106,13 @@ class Constraints[T: (int, str), U: (int, str), V: int]:
def _(value: object):
if isinstance(value, Constraints):
assert_type(value, Constraints[Any, Any, Any])

@runtime_checkable
class Foo[T: (int, str)](Protocol):
def asdf(self): ...

def _(
value: str | Foo[str],
):
if isinstance(value, Foo):
assert_type(value, Foo[str])

0 comments on commit 74f3b38

Please sign in to comment.