Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix _after_postgeneration being invoked twice #164

Merged
merged 9 commits into from
Jun 5, 2022
Merged
18 changes: 17 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,23 @@ Changelog

Unreleased
----------
- The fixture name for registered factories is now determined by the factory name (rather than the model name). This makes factories for builtin types (like ``dict``) easier to use.
- The generated fixture name is now determined by the factory (rather than the model). This makes factories for builtin types (like ``dict``) easier to use. You may need to change your factories to use the ``<model>Factory`` naming convention, or use the ``register(_name=...)`` override. `#163 <https://github.com/pytest-dev/pytest-factoryboy/pull/163>`_

.. code-block:: python

# example
@register
class HTTPHeadersFactory(factory.Factory):
class Meta:
model = dict # no need to use a special dict subclass anymore

Authorization = "Basic Zm9vOmJhcg=="


def test_headers(headers):
assert headers["Authorization"] == "Basic Zm9vOmJhcg=="

- Fix ``Factory._after_postgeneration`` being invoked twice. `#164 <https://github.com/pytest-dev/pytest-factoryboy/pull/164>`_ `#156 <https://github.com/pytest-dev/pytest-factoryboy/issues/156>`_

2.4.0
----------
Expand Down
43 changes: 33 additions & 10 deletions pytest_factoryboy/fixture.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
"""Factory boy fixture integration."""
from __future__ import annotations

import contextlib
import functools
import sys
from dataclasses import dataclass
from inspect import signature
from types import MethodType
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Generic,
Iterable,
Iterator,
Mapping,
Type,
TypeVar,
Expand All @@ -24,16 +27,13 @@
import factory.declarations
import factory.enums
import inflection
from factory.declarations import NotProvided
from typing_extensions import ParamSpec, TypeAlias

from .compat import PostGenerationContext
from .fixturegen import create_fixture

if TYPE_CHECKING:
from _pytest.fixtures import SubRequest
from factory.builder import BuildStep
from factory.declarations import PostGeneration, PostGenerationContext

from .plugin import Request as FactoryboyRequest

Expand Down Expand Up @@ -291,6 +291,24 @@ def evaluate(request: SubRequest, value: LazyFixture[T] | T) -> T:
return value.evaluate(request) if isinstance(value, LazyFixture) else value


def noop(*args: Any, **kwargs: Any) -> None:
"""No-op function."""
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIP: pass is not required when docstring is provided

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yeah true that



@contextlib.contextmanager
def disable_method(method: MethodType) -> Iterator[None]:
"""Disable a method."""
klass = method.__self__
method_name = method.__name__
old_method = getattr(klass, method_name)
setattr(klass, method_name, noop)
try:
yield
finally:
setattr(klass, method.__name__, old_method)


def model_fixture(request: SubRequest, factory_name: str) -> Any:
"""Model fixture implementation."""
factoryboy_request: FactoryboyRequest = request.getfixturevalue("factoryboy_request")
Expand Down Expand Up @@ -328,7 +346,10 @@ def model_fixture(request: SubRequest, factory_name: str) -> Any:
builder = factory.builder.StepBuilder(Factory._meta, kwargs, strategy)
step = factory.builder.BuildStep(builder=builder, sequence=Factory._meta.next_sequence())

instance = Factory(**kwargs)
# FactoryBoy invokes the `_after_postgeneration` method, but we will instead call it manually later,
# once we are able to evaluate all the related fixtures.
with disable_method(Factory._after_postgeneration):
instance = Factory(**kwargs)

# Cache the instance value on pytest level so that the fixture can be resolved before the return
request._fixturedef.cached_result = (instance, 0, None)
Expand Down Expand Up @@ -360,7 +381,7 @@ def model_fixture(request: SubRequest, factory_name: str) -> Any:
# that `value_provided` should be falsy
postgen_value = evaluate(request, request.getfixturevalue(argname))
postgen_context = PostGenerationContext(
value_provided=(postgen_value is not NotProvided),
value_provided=(postgen_value is not factory.declarations.NotProvided),
value=postgen_value,
extra=extra,
)
Expand All @@ -369,7 +390,8 @@ def model_fixture(request: SubRequest, factory_name: str) -> Any:
)
factoryboy_request.defer(deferred)

# Try to evaluate as much post-generation dependencies as possible
# Try to evaluate as much post-generation dependencies as possible.
# This will finally invoke Factory._after_postgeneration, which was previously disabled
factoryboy_request.evaluate(request)
return instance

Expand Down Expand Up @@ -397,12 +419,12 @@ def deferred_impl(request: SubRequest) -> Any:


def make_deferred_postgen(
step: BuildStep,
step: factory.builder.BuildStep,
factory_class: FactoryType,
fixture: str,
instance: Any,
attr: str,
declaration: PostGeneration,
declaration: factory.declarations.PostGenerationDeclaration,
context: PostGenerationContext,
) -> DeferredFunction:
"""Make deferred function for the post-generation declaration.
Expand All @@ -412,6 +434,7 @@ def make_deferred_postgen(
:param fixture: Object fixture name e.g. "author".
:param instance: Parent object instance.
:param attr: Declaration attribute name e.g. "register_user".
:param declaration: Post-generation declaration.
:param context: Post-generation declaration context.

:note: Deferred function name results in "author__register_user".
Expand Down Expand Up @@ -445,9 +468,9 @@ def subfactory_fixture(request: SubRequest, factory_class: FactoryType) -> Any:
return request.getfixturevalue(fixture)


def get_caller_locals(depth: int = 2) -> dict[str, Any]:
def get_caller_locals(depth: int = 0) -> dict[str, Any]:
"""Get the local namespace of the caller frame."""
return sys._getframe(depth).f_locals
return sys._getframe(depth + 2).f_locals
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this change is required?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed, but it makes more sense, so that the caller can specify if it wants the locals from the caller (from its POV) or levels from its POV



class LazyFixture(Generic[T]):
Expand Down
27 changes: 27 additions & 0 deletions tests/test_postgen_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,33 @@ def test_postgenerationmethodcall_fixture(foo: Foo):
assert foo.number == 456


class TestPostgenerationCalledOnce:
@register(_name="collector")
class CollectorFactory(factory.Factory):
class Meta:
model = dict

foo = factory.PostGeneration(lambda *args, **kwargs: 42)

@classmethod
def _after_postgeneration(
cls, obj: dict[str, Any], create: bool, results: dict[str, Any] | None = None
) -> None:
obj.setdefault("_after_postgeneration_calls", []).append((obj, create, results))

def test_postgeneration_called_once(self, request):
"""Test that ``_after_postgeneration`` is called only once."""
foo = request.getfixturevalue("collector")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not add collector as a test's function argument to let pytest do the job of getting the fixture value?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's just in case there is an error in the setup, it's nicer to have the tests error out at runtime, rather than at collection time (especially if doing TDD), as it would prevent you from running other tests

calls = foo["_after_postgeneration_calls"]
assert len(calls) == 1
[[obj, create, results]] = calls

assert obj is foo
assert create is True
assert isinstance(results, dict)
assert results["foo"] == 42


@dataclass
class Ordered:
value: str | None = None
Expand Down