Skip to content

Commit dbcf79d

Browse files
authored
fix: Generate correct collection size when annotation_types.Len is used (#712)
1 parent d03aa1b commit dbcf79d

File tree

3 files changed

+51
-29
lines changed

3 files changed

+51
-29
lines changed

polyfactory/factories/pydantic_factory.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from polyfactory.field_meta import Constraints, FieldMeta, Null
1717
from polyfactory.utils.deprecation import check_for_deprecated_parameters
1818
from polyfactory.utils.helpers import unwrap_new_type, unwrap_optional
19-
from polyfactory.utils.predicates import is_optional, is_safe_subclass, is_union
19+
from polyfactory.utils.predicates import is_annotated, is_optional, is_safe_subclass, is_union
2020
from polyfactory.utils.types import NoneType
2121
from polyfactory.value_generators.primitives import create_random_bytes
2222

@@ -270,31 +270,37 @@ def from_model_field( # pragma: no cover
270270
else unwrap_new_type(model_field.annotation)
271271
)
272272

273-
constraints = cast(
274-
"Constraints",
275-
{
276-
"ge": getattr(outer_type, "ge", model_field.field_info.ge),
277-
"gt": getattr(outer_type, "gt", model_field.field_info.gt),
278-
"le": getattr(outer_type, "le", model_field.field_info.le),
279-
"lt": getattr(outer_type, "lt", model_field.field_info.lt),
280-
"min_length": (
281-
getattr(outer_type, "min_length", model_field.field_info.min_length)
282-
or getattr(outer_type, "min_items", model_field.field_info.min_items)
283-
),
284-
"max_length": (
285-
getattr(outer_type, "max_length", model_field.field_info.max_length)
286-
or getattr(outer_type, "max_items", model_field.field_info.max_items)
287-
),
288-
"pattern": getattr(outer_type, "regex", model_field.field_info.regex),
289-
"unique_items": getattr(outer_type, "unique_items", model_field.field_info.unique_items),
290-
"decimal_places": getattr(outer_type, "decimal_places", None),
291-
"max_digits": getattr(outer_type, "max_digits", None),
292-
"multiple_of": getattr(outer_type, "multiple_of", None),
293-
"upper_case": getattr(outer_type, "to_upper", None),
294-
"lower_case": getattr(outer_type, "to_lower", None),
295-
"item_type": getattr(outer_type, "item_type", None),
296-
},
297-
)
273+
# In pydantic v1, we need to check if the annotation is directly annotated to properly extract constraints
274+
# from the metadata, as v1 doesn't automatically propagate constraints like v2 does
275+
annotation_constraints: Constraints = {}
276+
if is_annotated(model_field.annotation):
277+
annotation_metadata = cls.get_constraints_metadata(model_field.annotation)
278+
annotation_constraints = cls.parse_constraints(annotation_metadata) if annotation_metadata else {}
279+
280+
field_info_constraints = {
281+
"ge": getattr(outer_type, "ge", model_field.field_info.ge),
282+
"gt": getattr(outer_type, "gt", model_field.field_info.gt),
283+
"le": getattr(outer_type, "le", model_field.field_info.le),
284+
"lt": getattr(outer_type, "lt", model_field.field_info.lt),
285+
"min_length": (
286+
getattr(outer_type, "min_length", model_field.field_info.min_length)
287+
or getattr(outer_type, "min_items", model_field.field_info.min_items)
288+
),
289+
"max_length": (
290+
getattr(outer_type, "max_length", model_field.field_info.max_length)
291+
or getattr(outer_type, "max_items", model_field.field_info.max_items)
292+
),
293+
"pattern": getattr(outer_type, "regex", model_field.field_info.regex),
294+
"unique_items": getattr(outer_type, "unique_items", model_field.field_info.unique_items),
295+
"decimal_places": getattr(outer_type, "decimal_places", None),
296+
"max_digits": getattr(outer_type, "max_digits", None),
297+
"multiple_of": getattr(outer_type, "multiple_of", None),
298+
"upper_case": getattr(outer_type, "to_upper", None),
299+
"lower_case": getattr(outer_type, "to_lower", None),
300+
"item_type": getattr(outer_type, "item_type", None),
301+
}
302+
303+
constraints = cast("Constraints", {**field_info_constraints, **annotation_constraints})
298304

299305
# pydantic v1 has constraints set for these values, but we generate them using faker
300306
if unwrap_optional(annotation) in (

polyfactory/value_generators/constrained_collections.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def handle_constrained_collection( # noqa: C901
4242
return collection_type()
4343

4444
min_items = abs(min_items if min_items is not None else (max_items or 0))
45-
max_items = abs(max_items if max_items is not None else min_items + 1)
45+
max_items = abs(max_items) if max_items is not None else min_items
4646

4747
if max_items < min_items:
4848
msg = "max_items must be larger or equal to min_items"
@@ -110,13 +110,13 @@ def handle_constrained_mapping(
110110
return {}
111111

112112
min_items = abs(min_items if min_items is not None else (max_items or 0))
113-
max_items = abs(max_items if max_items is not None else min_items + 1)
113+
max_items = abs(max_items) if max_items is not None else min_items
114114

115115
if max_items < min_items:
116116
msg = "max_items must be larger or equal to min_items"
117117
raise ParameterException(msg)
118118

119-
length = factory.__random__.randint(min_items, max_items) or 1
119+
length = factory.__random__.randint(min_items, max_items)
120120

121121
collection: dict[Any, Any] = {}
122122

tests/test_collection_length.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from typing import Any, Dict, FrozenSet, List, Literal, Optional, Set, Tuple, get_args
33

44
import pytest
5+
from annotated_types import Len
6+
from typing_extensions import Annotated
57

68
from pydantic import BaseModel
79
from pydantic.dataclasses import dataclass
@@ -12,6 +14,20 @@
1214
MIN_MAX_PARAMETERS = ((10, 15), (20, 25), (30, 40), (40, 50))
1315

1416

17+
@pytest.mark.parametrize("type_", (List[int], Dict[int, int]))
18+
@pytest.mark.parametrize("length", range(1, 10))
19+
def test_annotated_type_collection_length(type_: type, length: int) -> None:
20+
class Foo(BaseModel):
21+
foo: Annotated[type_, Len(length)] # type: ignore
22+
23+
class FooFactory(ModelFactory):
24+
__model__ = Foo
25+
26+
for _ in range(10):
27+
foo = FooFactory.build()
28+
assert len(foo.foo) == length, len(foo.foo)
29+
30+
1531
@pytest.mark.parametrize("type_", (List, Set))
1632
@pytest.mark.parametrize(("min_val", "max_val"), MIN_MAX_PARAMETERS)
1733
def test_collection_length_with_list(min_val: int, max_val: int, type_: type) -> None:

0 commit comments

Comments
 (0)