diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index 3ebde159a6..6a4e489f5b 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -10,6 +10,7 @@ */ import { assert } from '../common/debug'; +import { Uri } from '../common/uri/uri'; import { ArgCategory, AssignmentExpressionNode, @@ -95,6 +96,7 @@ import { specializeTupleClass, specializeWithUnknownTypeArgs, stripTypeForm, + synthesizeTypeVarForSelfCls, transformPossibleRecursiveTypeAlias, } from './typeUtils'; @@ -1814,7 +1816,18 @@ function narrowTypeForInstance( } } - if (isClass(subtype)) { + if ( + isClass(subtype) || + // when strictGenericNarrowing is enabled, we need to convert the Callable type to a Callable + // protocol to make sure it keeps the generics when narrowing, but this is only needed for + // isinstance checks because you can't specify generics to it + (!isTypeIsCheck && + getFileInfo(errorNode).diagnosticRuleSet.strictGenericNarrowing && + isFunction(subtype)) + ) { + if (isFunction(subtype)) { + subtype = synthesizeCallableProtocolFromFunctionType(evaluator, subtype, errorNode); + } return combineTypes( filterClassType( unexpandedSubtype, @@ -1850,6 +1863,54 @@ function narrowTypeForInstance( return filteredType; } +/** + * the logic for narrowing `typing.Callable` (`FunctionType`) is completely different to the logic + * for narrowing callable protocols. when `typing.Callable`s are narrowed, it does not retain the + * generics from the supertype, so we create a fake callable protocol from a `FunctionType` so it + * can be narrowed using the same logic that's used to narrow `ClassType`s. + * + * this is not ideal and probably super hacky, but i couldnt figure out how to update the narrowing + * logic for `FunctionType` so this solution was easier. + */ +const synthesizeCallableProtocolFromFunctionType = ( + evaluator: TypeEvaluator, + callable: FunctionType, + errorNode: ParseNode +): ClassType => { + //TODO: fix hover text. currently this causes narrowed `FunctionType`s to display like this: + // "" + const callableType = ClassType.createInstantiable( + 'Callable', + '', + '', + Uri.empty(), + ClassTypeFlags.ProtocolClass, + 0, + undefined, + undefined + ); + callableType.shared.baseClasses.push(evaluator.getBuiltInType(errorNode, 'Protocol')); + computeMroLinearization(callableType); + const fields = ClassType.getSymbolTable(callableType); + const callMethod = FunctionType.createSynthesizedInstance(callable.shared.name); + FunctionType.addParam( + callMethod, + FunctionParam.create( + ParamCategory.Simple, + synthesizeTypeVarForSelfCls(callableType, false), + FunctionParamFlags.TypeDeclared, + 'self' + ) + ); + for (const parameter of callable.shared.parameters) { + FunctionType.addParam(callMethod, parameter); + } + callMethod.shared.declaredReturnType = FunctionType.getEffectiveReturnType(callable); + const callSymbol = Symbol.createWithType(SymbolFlags.ClassMember, callMethod); + fields.set('__call__', callSymbol); + return ClassType.cloneAsInstance(callableType); +}; + // This function assumes that the caller has already verified that the two // types are the same class and are not literals. It also assumes that the // caller has verified that type1 is not assignable to type2 or vice versa. diff --git a/packages/pyright-internal/src/analyzer/typeUtils.ts b/packages/pyright-internal/src/analyzer/typeUtils.ts index c61d434fc6..2b889e4413 100644 --- a/packages/pyright-internal/src/analyzer/typeUtils.ts +++ b/packages/pyright-internal/src/analyzer/typeUtils.ts @@ -1089,10 +1089,7 @@ export const shouldUseVarianceForSpecialization = (typeToNarrow: Type, strictGen } return allSubtypes( typeToNarrow, - (subtype) => - subtype.category !== TypeCategory.Class || - // !ClassType.isSameGenericClass(subtype, narrowToType) || - subtype.shared.typeParams.length === 0 + (subtype) => subtype.category !== TypeCategory.Class || subtype.shared.typeParams.length === 0 ); }; diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingUsingBounds.py b/packages/pyright-internal/src/tests/samples/typeNarrowingUsingBounds.py index d1f07adb6c..c5abf9b6a3 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingUsingBounds.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingUsingBounds.py @@ -1,5 +1,5 @@ -from typing import Any, Never, assert_type, runtime_checkable, Protocol, Iterable, Iterator, MutableMapping, Reversible - +from typing import Any, Never, assert_type, runtime_checkable, Protocol, Iterable, Iterator, MutableMapping, Reversible, Callable, TypeIs +from types import FunctionType class Covariant[T]: def foo(self, other: object): @@ -115,4 +115,29 @@ def _( value: str | Foo[str], ): if isinstance(value, Foo): - assert_type(value, Foo[str]) \ No newline at end of file + assert_type(value, Foo[str]) + +def _(f: Callable[[int], str]): + if isinstance(f, staticmethod): + # can't use assert_type on the function itself, see TODO in synthesizeCallableProtocolFromFunctionType + assert_type(f(1), str) + assert_type(f.__call__, Callable[[int], str]) + reveal_type(f) + +class CallableProtocol[**P, T](Protocol): + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ... + + +def _(f: CallableProtocol[[], None]): + if isinstance(f, staticmethod): + assert_type(f, staticmethod[[], None]) + +def _(f: Callable[[int], str]): + if isinstance(f, FunctionType): + assert_type(f, FunctionType) + +def takes_arg(value: object) -> TypeIs[Callable[[int], None]]: ... + +def _(value: Callable[[], None] | Callable[[int], None]): + if takes_arg(value): + assert_type(value, Callable[[int], None]) \ No newline at end of file diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingUsingBoundsDisabled.py b/packages/pyright-internal/src/tests/samples/typeNarrowingUsingBoundsDisabled.py index 7f5a658f4c..9860d43583 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingUsingBoundsDisabled.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingUsingBoundsDisabled.py @@ -1,4 +1,5 @@ -from typing import Any, assert_type, runtime_checkable, Protocol, Iterable, Iterator, MutableMapping, Reversible +from typing import Any, assert_type, runtime_checkable, Protocol, Iterable, Iterator, MutableMapping, Reversible, Callable, TypeIs +from types import FunctionType class Covariant[T]: @@ -115,4 +116,26 @@ def _( value: str | Foo[str], ): if isinstance(value, Foo): - assert_type(value, Foo[str]) \ No newline at end of file + assert_type(value, Foo[str]) + +def _(f: Callable[[], None]): + if isinstance(f, staticmethod): + assert_type(f, staticmethod[..., Any]) + +class CallableProtocol[**P, T](Protocol): + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ... + + +def _(f: CallableProtocol[[], None]): + if isinstance(f, staticmethod): + assert_type(f, staticmethod[[], None]) + +def _(f: Callable[[int], str]): + if isinstance(f, FunctionType): + assert_type(f, FunctionType) + +def takes_arg(value: object) -> TypeIs[Callable[[int], None]]: ... + +def _(value: Callable[[], None] | Callable[[int], None]): + if takes_arg(value): + assert_type(value, Callable[[int], None]) \ No newline at end of file diff --git a/packages/pyright-internal/src/tests/typeEvaluatorBased.test.ts b/packages/pyright-internal/src/tests/typeEvaluatorBased.test.ts index c6e9efc1bc..74396c1dfd 100644 --- a/packages/pyright-internal/src/tests/typeEvaluatorBased.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluatorBased.test.ts @@ -126,6 +126,7 @@ describe('narrowing type vars using their bounds', () => { const analysisResults = typeAnalyzeSampleFiles(['typeNarrowingUsingBounds.py'], configOptions); validateResultsButBased(analysisResults, { errors: [], + infos: [{ line: 124, message: 'Type of "f" is ""' }], }); }); test('disabled', () => {