Skip to content

Commit

Permalink
cleanup patches after test in pytest plugin (#1148) (#1164)
Browse files Browse the repository at this point in the history
* cleanup patches after test in pytest plugin (#1148)

* pin dependency in docs causing problems with breaking change

* simplify Generator annotations; use mapping proxy to ensure immutable module constant
  • Loading branch information
ariebovenberg authored Dec 7, 2021
1 parent 15cbb83 commit c961ebf
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 130 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ See [0Ver](https://0ver.org/).
### Bugfixes

- Fixes `__slots__` not being set properly in containers and their base classes
- Fixes patching of containers in pytest plugin not undone after each test

## 0.17.0

Expand Down
4 changes: 4 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,7 @@ hypothesis==6.30.1
# TODO: Remove this lock when we found and fix the route case.
# See: https://github.com/typlog/sphinx-typlog-theme/issues/22
jinja2==3.0.3

# TODO: Remove this lock when this dependency issue is resolved.
# See: https://github.com/miyakogi/m2r/issues/66
mistune<2.0.0
239 changes: 112 additions & 127 deletions returns/contrib/pytest/plugin.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,17 @@
import inspect
import sys
from contextlib import contextmanager
from contextlib import ExitStack, contextmanager
from functools import partial, wraps
from types import FrameType
from types import FrameType, MappingProxyType
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, TypeVar, Union
from unittest import mock

import pytest
from typing_extensions import Final, final

if TYPE_CHECKING:
from returns.interfaces.specific.result import ResultLikeN

_ERROR_HANDLERS: Final = (
'lash',
)
_ERRORS_COPIERS: Final = (
'map',
'alt',
)

# We keep track of errors handled by keeping a mapping of <object id>: object.
# If an error is handled, it is in the mapping.
# If it isn't in the mapping, the error is not handled.
Expand All @@ -28,7 +21,7 @@
# Also, the object itself cannot be (in) the key because
# (1) we cannot always assume hashability and
# (2) we need to track the object identity, not its value
_ERRORS_HANDLED: Final[Dict[int, Any]] = {} # noqa: WPS407
_ErrorsHandled = Dict[int, Any]

_FunctionType = TypeVar('_FunctionType', bound=Callable)
_ReturnsResultType = TypeVar(
Expand All @@ -41,7 +34,11 @@
class ReturnsAsserts(object):
"""Class with helpers assertions to check containers."""

__slots__ = ()
__slots__ = ('_errors_handled', )

def __init__(self, errors_handled: _ErrorsHandled) -> None:
"""Constructor for this type."""
self._errors_handled = errors_handled

@staticmethod # noqa: WPS602
def assert_equal( # noqa: WPS602
Expand All @@ -55,10 +52,9 @@ def assert_equal( # noqa: WPS602
from returns.primitives.asserts import assert_equal
assert_equal(first, second, deps=deps, backend=backend)

@staticmethod # noqa: WPS602
def is_error_handled(container) -> bool: # noqa: WPS602
def is_error_handled(self, container) -> bool:
"""Ensures that container has its error handled in the end."""
return id(container) in _ERRORS_HANDLED
return id(container) in self._errors_handled

@staticmethod # noqa: WPS602
@contextmanager
Expand Down Expand Up @@ -86,59 +82,6 @@ def assert_trace( # noqa: WPS602
sys.settrace(old_tracer)


@pytest.fixture(scope='session')
def returns(_patch_containers) -> ReturnsAsserts:
"""Returns our own class with helpers assertions to check containers."""
return ReturnsAsserts()


@pytest.fixture(autouse=True)
def _clear_errors_handled():
"""Ensures the 'errors handled' registry doesn't leak memory."""
yield
_ERRORS_HANDLED.clear()


def pytest_configure(config) -> None:
"""
Hook to be executed on import.
We use it define custom markers.
"""
config.addinivalue_line(
'markers',
(
'returns_lawful: all tests under `check_all_laws` ' +
'is marked this way, ' +
'use `-m "not returns_lawful"` to skip them.'
),
)


@pytest.fixture(scope='session')
def _patch_containers() -> None:
"""
Fixture to add test specifics into our containers.
Currently we inject:
- Error handling state, this is required to test that ``Result``-based
containers do handle errors
Even more things to come!
"""
_patch_error_handling(_ERROR_HANDLERS, _PatchedContainer.error_handler)
_patch_error_handling(_ERRORS_COPIERS, _PatchedContainer.copy_handler)


def _patch_error_handling(methods, patch_handler) -> None:
for container in _PatchedContainer.containers_to_patch():
for method in methods:
original = getattr(container, method, None)
if original:
setattr(container, method, patch_handler(original))


def _trace_function(
trace_type: _ReturnsResultType,
function_to_search: _FunctionType,
Expand Down Expand Up @@ -166,65 +109,107 @@ def _trace_function(
raise _DesiredFunctionFound()


@final
class _PatchedContainer(object):
"""Class with helper methods to patched containers."""

__slots__ = ()

@classmethod
def containers_to_patch(cls) -> tuple:
"""We need this method so coverage will work correctly."""
from returns.context import (
RequiresContextFutureResult,
RequiresContextIOResult,
RequiresContextResult,
)
from returns.future import FutureResult
from returns.io import IOFailure, IOSuccess
from returns.result import Failure, Success

return (
Success,
Failure,
IOSuccess,
IOFailure,
RequiresContextResult,
RequiresContextIOResult,
RequiresContextFutureResult,
FutureResult,
)
class _DesiredFunctionFound(BaseException): # noqa: WPS418
"""Exception to raise when expected function is found."""

@classmethod
def error_handler(cls, original):
if inspect.iscoroutinefunction(original):
async def factory(self, *args, **kwargs):
original_result = await original(self, *args, **kwargs)
_ERRORS_HANDLED[id(original_result)] = original_result
return original_result
else:
def factory(self, *args, **kwargs):
original_result = original(self, *args, **kwargs)
_ERRORS_HANDLED[id(original_result)] = original_result
return original_result
return wraps(original)(factory)

@classmethod
def copy_handler(cls, original):
if inspect.iscoroutinefunction(original):
async def factory(self, *args, **kwargs):
original_result = await original(self, *args, **kwargs)
if id(self) in _ERRORS_HANDLED:
_ERRORS_HANDLED[id(original_result)] = original_result
return original_result
else:
def factory(self, *args, **kwargs):
original_result = original(self, *args, **kwargs)
if id(self) in _ERRORS_HANDLED:
_ERRORS_HANDLED[id(original_result)] = original_result
return original_result
return wraps(original)(factory)

def pytest_configure(config) -> None:
"""
Hook to be executed on import.
class _DesiredFunctionFound(BaseException): # noqa: WPS418
"""Exception to raise when expected function is found."""
We use it define custom markers.
"""
config.addinivalue_line(
'markers',
(
'returns_lawful: all tests under `check_all_laws` ' +
'is marked this way, ' +
'use `-m "not returns_lawful"` to skip them.'
),
)


@pytest.fixture()
def returns() -> Iterator[ReturnsAsserts]:
"""Returns class with helpers assertions to check containers."""
with _spy_error_handling() as errors_handled:
yield ReturnsAsserts(errors_handled)


@contextmanager
def _spy_error_handling() -> Iterator[_ErrorsHandled]:
"""Track error handling of containers."""
errs: _ErrorsHandled = {}
with ExitStack() as cleanup:
for container in _containers_to_patch():
for method, patch in _ERROR_HANDLING_PATCHERS.items():
cleanup.enter_context(mock.patch.object(
container,
method,
patch(getattr(container, method), errs=errs),
))
yield errs


# delayed imports are needed to prevent messing up coverage
def _containers_to_patch() -> tuple:
from returns.context import (
RequiresContextFutureResult,
RequiresContextIOResult,
RequiresContextResult,
)
from returns.future import FutureResult
from returns.io import IOFailure, IOSuccess
from returns.result import Failure, Success

return (
Success,
Failure,
IOSuccess,
IOFailure,
RequiresContextResult,
RequiresContextIOResult,
RequiresContextFutureResult,
FutureResult,
)


def _patched_error_handler(
original: _FunctionType, errs: _ErrorsHandled,
) -> _FunctionType:
if inspect.iscoroutinefunction(original):
async def wrapper(self, *args, **kwargs):
original_result = await original(self, *args, **kwargs)
errs[id(original_result)] = original_result
return original_result
else:
def wrapper(self, *args, **kwargs):
original_result = original(self, *args, **kwargs)
errs[id(original_result)] = original_result
return original_result
return wraps(original)(wrapper) # type: ignore


def _patched_error_copier(
original: _FunctionType, errs: _ErrorsHandled,
) -> _FunctionType:
if inspect.iscoroutinefunction(original):
async def wrapper(self, *args, **kwargs):
original_result = await original(self, *args, **kwargs)
if id(self) in errs:
errs[id(original_result)] = original_result
return original_result
else:
def wrapper(self, *args, **kwargs):
original_result = original(self, *args, **kwargs)
if id(self) in errs:
errs[id(original_result)] = original_result
return original_result
return wraps(original)(wrapper) # type: ignore


_ERROR_HANDLING_PATCHERS: Final = MappingProxyType({
'lash': _patched_error_handler,
'map': _patched_error_copier,
'alt': _patched_error_copier,
})
9 changes: 6 additions & 3 deletions tests/test_contrib/test_pytest/test_plugin_error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
RequiresContextResult,
)
from returns.contrib.pytest import ReturnsAsserts
from returns.contrib.pytest.plugin import _ERRORS_HANDLED
from returns.functions import identity
from returns.future import FutureResult
from returns.io import IOFailure, IOSuccess
Expand Down Expand Up @@ -42,13 +41,15 @@ def _under_test(
])
def test_error_handled(returns: ReturnsAsserts, container, kwargs):
"""Demo on how to use ``pytest`` helpers to work with error handling."""
assert not _ERRORS_HANDLED
assert not returns._errors_handled # noqa: WPS437
error_handled = _under_test(container, **kwargs)

assert returns.is_error_handled(error_handled)
assert returns.is_error_handled(error_handled.map(identity))
assert returns.is_error_handled(error_handled.alt(identity))

assert returns._errors_handled # noqa: WPS437


@pytest.mark.parametrize('container', [
Success(1),
Expand All @@ -64,14 +65,16 @@ def test_error_handled(returns: ReturnsAsserts, container, kwargs):
])
def test_error_not_handled(returns: ReturnsAsserts, container):
"""Demo on how to use ``pytest`` helpers to work with error handling."""
assert not _ERRORS_HANDLED
assert not returns._errors_handled # noqa: WPS437
error_handled = _under_test(container)

assert not returns.is_error_handled(container)
assert not returns.is_error_handled(error_handled)
assert not returns.is_error_handled(error_handled.map(identity))
assert not returns.is_error_handled(error_handled.alt(identity))

assert not returns._errors_handled # noqa: WPS437


@pytest.mark.anyio()
@pytest.mark.parametrize('container', [
Expand Down

0 comments on commit c961ebf

Please sign in to comment.