Skip to content

Commit

Permalink
don't do the variance narrowing thing if the type was originally `Any…
Browse files Browse the repository at this point in the history
…`/Unknown
  • Loading branch information
DetachHead committed Oct 6, 2024
1 parent 1bc9a3d commit 6e30215
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 10 deletions.
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/analyzer/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3794,7 +3794,7 @@ export class Checker extends ParseTreeWalker {
return;
}

const classTypeList = getIsInstanceClassTypes(this._evaluator, arg1Type);
const classTypeList = getIsInstanceClassTypes(this._evaluator, arg1Type, arg0Type);
if (!classTypeList) {
return;
}
Expand Down
8 changes: 7 additions & 1 deletion packages/pyright-internal/src/analyzer/patternMatching.ts
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,13 @@ function narrowTypeBasedOnClassPattern(
if (isClass(exprType) && !exprType.props?.typeAliasInfo) {
exprType = ClassType.cloneRemoveTypePromotions(exprType);
evaluator.inferVarianceForClass(exprType);
exprType = specializeWithUnknownTypeArgs(exprType, evaluator.getTupleClassType(), evaluator.getObjectType());
exprType = specializeWithUnknownTypeArgs(
exprType,
evaluator.getTupleClassType(),
// for backwards compatibility with bacly typed code, we don't specialize using variance if the type we're
// narrowing is Any/Unknown
isAnyOrUnknown(type) ? undefined : evaluator.getObjectType()
);
}

// Are there any positional arguments? If so, try to get the mappings for
Expand Down
17 changes: 13 additions & 4 deletions packages/pyright-internal/src/analyzer/typeGuards.ts
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,11 @@ export function getTypeNarrowingCallback(
const arg1TypeResult = evaluator.getTypeOfExpression(arg1Expr, EvalFlags.IsInstanceArgDefaults);
const arg1Type = arg1TypeResult.type;

const classTypeList = getIsInstanceClassTypes(evaluator, arg1Type);
const classTypeList = getIsInstanceClassTypes(
evaluator,
arg1Type,
evaluator.getTypeOfExpression(arg0Expr).type
);
const isIncomplete = !!callTypeResult.isIncomplete || !!arg1TypeResult.isIncomplete;

if (classTypeList) {
Expand Down Expand Up @@ -1125,11 +1129,16 @@ function narrowTypeForIsEllipsis(evaluator: TypeEvaluator, node: ExpressionNode,
// which form and returns a list of classes or undefined.
export function getIsInstanceClassTypes(
evaluator: TypeEvaluator,
argType: Type
argType: Type,
typeToNarrow: Type
): (ClassType | TypeVarType | FunctionType)[] | undefined {
let foundNonClassType = false;
const classTypeList: (ClassType | TypeVarType | FunctionType)[] = [];

/**
* if the type we're narrowing is Any or Unknown, we don't want to specialize using the
* variance/bound for compatibility with less strictly typed code (cringe)
*/
const useVarianceForSpecialization = !isAnyOrUnknown(typeToNarrow);
// Create a helper function that returns a list of class types or
// undefined if any of the types are not valid.
const addClassTypesToList = (types: Type[]) => {
Expand All @@ -1139,7 +1148,7 @@ export function getIsInstanceClassTypes(
subtype = specializeWithUnknownTypeArgs(
subtype,
evaluator.getTupleClassType(),
evaluator.getObjectType()
useVarianceForSpecialization ? evaluator.getObjectType() : undefined
);

if (isInstantiableClass(subtype) && ClassType.isBuiltIn(subtype, 'Callable')) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

def func1(v: Any) -> bool:
if isinstance(v, Iterable):
reveal_type(v, expected_text="Iterable[object]")
reveal_type(v, expected_text="Iterable[Unknown]")
if isinstance(v, Sized):
reveal_type(v, expected_text="<subclass of Iterable[object] and Sized>")
reveal_type(v, expected_text="<subclass of Iterable and Sized>")
return True
return False
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Never, assert_type, Iterable
from typing import Any, Never, assert_type, Iterable


class Covariant[T]:
Expand Down Expand Up @@ -40,4 +40,14 @@ def bar(self, other: T): ...
def foo(value: object):
match value:
case Iterable():
assert_type(value, Iterable[object])
assert_type(value, Iterable[object])

class AnyOrUnknown:
"""for backwards compatibility with badly typed code we keep the old functionality when narrowing `Any`/Unknown"""
def foo(self, value: Any):
if isinstance(value, Iterable):
assert_type(value, Iterable[Any])
def bar(self, value: Any):
match value:
case Iterable():
assert_type(value, Iterable[Any])

0 comments on commit 6e30215

Please sign in to comment.