diff --git a/packages/pyright-internal/src/tests/checker.test.ts b/packages/pyright-internal/src/tests/checker.test.ts index da8476e5b..b157eecf7 100644 --- a/packages/pyright-internal/src/tests/checker.test.ts +++ b/packages/pyright-internal/src/tests/checker.test.ts @@ -163,6 +163,11 @@ test('With2', () => { TestUtils.validateResults(analysisResults, 3); }); +test('context manager where __exit__ returns bool | None', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['withBased.py']); + TestUtils.validateResultsButBased(analysisResults, { unreachableCodes: [{ line: 47 }], unusedCodes: undefined }); +}); + test('With3', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['with3.py']); diff --git a/packages/pyright-internal/src/tests/samples/withBased.py b/packages/pyright-internal/src/tests/samples/withBased.py new file mode 100644 index 000000000..1ff09abbf --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/withBased.py @@ -0,0 +1,48 @@ +import contextlib +from types import TracebackType +from typing import Iterator, Literal + +from typing_extensions import assert_never + +class BoolOrNone(contextlib.AbstractContextManager[None]): + def __exit__( + self, + __exc_type: type[BaseException] | None, + __exc_value: BaseException | None, + __traceback: TracebackType | None, + ) -> bool | None: + ... + +def _(): + with BoolOrNone(): + raise Exception + print(1) # reachable + +class TrueOrNone(contextlib.AbstractContextManager[None]): + def __exit__( + self, + __exc_type: type[BaseException] | None, + __exc_value: BaseException | None, + __traceback: TracebackType | None, + ) -> Literal[True] | None: + ... + +def _(): + with TrueOrNone(): + raise Exception + print(1) # reachable + + +class FalseOrNone(contextlib.AbstractContextManager[None]): + def __exit__( + self, + __exc_type: type[BaseException] | None, + __exc_value: BaseException | None, + __traceback: TracebackType | None, + ) -> Literal[False] | None: + ... + +def _(): + with FalseOrNone(): + raise Exception + print(1) # unreachable \ No newline at end of file diff --git a/packages/pyright-internal/src/tests/testUtils.ts b/packages/pyright-internal/src/tests/testUtils.ts index 2becc39a6..7f71e0003 100644 --- a/packages/pyright-internal/src/tests/testUtils.ts +++ b/packages/pyright-internal/src/tests/testUtils.ts @@ -245,7 +245,11 @@ export const validateResultsButBased = (allResults: FileAnalysisResult[], expect code: result.getRule() as DiagnosticRule | undefined, }) ); - const expectedResult = expectedResults[diagnosticType] ?? []; - expect(new Set(actualResult)).toEqual(new Set(expectedResult.map(expect.objectContaining))); + const expectedResult = expectedResults[diagnosticType]; + // if it's explicitly in the expected results as undefined, that means we don't care. + // if it's not in the expected results at all, then check it + if (!(diagnosticType in expectedResults) || expectedResult !== undefined) { + expect(new Set(actualResult)).toEqual(new Set((expectedResult ?? []).map(expect.objectContaining))); + } } };