Skip to content

Commit

Permalink
refactor: update Pydantic imports (#625)
Browse files Browse the repository at this point in the history
  • Loading branch information
adhtruong authored Dec 28, 2024
1 parent d6a886a commit a7dda85
Showing 1 changed file with 75 additions and 77 deletions.
152 changes: 75 additions & 77 deletions polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,41 +25,54 @@

try:
import pydantic
from pydantic import VERSION, Json
from pydantic import (
VERSION,
AnyHttpUrl,
AnyUrl,
ByteSize,
EmailStr,
FutureDate,
HttpUrl,
IPvAnyAddress,
IPvAnyInterface,
IPvAnyNetwork,
Json,
NameEmail,
NegativeFloat,
NegativeInt,
NonNegativeInt,
NonPositiveFloat,
PastDate,
PaymentCardNumber,
PositiveFloat,
PositiveInt,
SecretBytes,
SecretStr,
StrictBool,
StrictBytes,
StrictFloat,
StrictInt,
StrictStr,
)
from pydantic.fields import FieldInfo
except ImportError as e:
msg = "pydantic is not installed"
raise MissingDependencyException(msg) from e

try:
# pydantic v1
from pydantic import ( # noqa: I001
UUID1,
UUID3,
UUID4,
UUID5,
AmqpDsn,
AnyHttpUrl,
AnyUrl,
DirectoryPath,
FilePath,
HttpUrl,
KafkaDsn,
PostgresDsn,
RedisDsn,
)
import pydantic as pydantic_v1
from pydantic import BaseModel as BaseModelV1

# Keep this import last to prevent warnings from pydantic if pydantic v2
# is installed.
from pydantic.color import Color
from pydantic.fields import ( # type: ignore[attr-defined]
DeferredType, # pyright: ignore[attr-defined,reportAttributeAccessIssue]
ModelField, # pyright: ignore[attr-defined,reportAttributeAccessIssue]
Undefined, # pyright: ignore[attr-defined,reportAttributeAccessIssue]
)

# Keep this import last to prevent warnings from pydantic if pydantic v2
# is installed.
from pydantic import PyObject

# prevent unbound variable warnings
BaseModelV2 = BaseModelV1
UndefinedV2 = Undefined
Expand All @@ -71,22 +84,7 @@
from pydantic_core import PydanticUndefined as UndefinedV2
from pydantic_core import to_json

from pydantic.v1 import ( # v1 compat imports
UUID1,
UUID3,
UUID4,
UUID5,
AmqpDsn,
AnyHttpUrl,
AnyUrl,
DirectoryPath,
FilePath,
HttpUrl,
KafkaDsn,
PostgresDsn,
PyObject,
RedisDsn,
)
import pydantic.v1 as pydantic_v1 # type: ignore[no-redef]
from pydantic.v1 import BaseModel as BaseModelV1 # type: ignore[assignment]
from pydantic.v1.color import Color # type: ignore[assignment]
from pydantic.v1.fields import DeferredType, ModelField, Undefined
Expand Down Expand Up @@ -299,10 +297,10 @@ def from_model_field( # pragma: no cover
if unwrap_optional(annotation) in (
AnyUrl,
HttpUrl,
KafkaDsn,
PostgresDsn,
RedisDsn,
AmqpDsn,
pydantic_v1.KafkaDsn,
pydantic_v1.PostgresDsn,
pydantic_v1.RedisDsn,
pydantic_v1.AmqpDsn,
AnyHttpUrl,
):
constraints = {}
Expand Down Expand Up @@ -554,48 +552,48 @@ def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: Any) -> bool:
@classmethod
def get_provider_map(cls) -> dict[Any, Callable[[], Any]]:
mapping: dict[Any, Callable[[], Any]] = {
pydantic.ByteSize: cls.__faker__.pyint,
pydantic.PositiveInt: cls.__faker__.pyint,
pydantic.NegativeFloat: lambda: cls.__random__.uniform(-100, -1),
pydantic.NegativeInt: lambda: cls.__faker__.pyint() * -1,
pydantic.PositiveFloat: cls.__faker__.pyint,
pydantic.NonPositiveFloat: lambda: cls.__random__.uniform(-100, 0),
pydantic.NonNegativeInt: cls.__faker__.pyint,
pydantic.StrictInt: cls.__faker__.pyint,
pydantic.StrictBool: cls.__faker__.pybool,
pydantic.StrictBytes: partial(create_random_bytes, cls.__random__),
pydantic.StrictFloat: cls.__faker__.pyfloat,
pydantic.StrictStr: cls.__faker__.pystr,
pydantic.EmailStr: cls.__faker__.free_email,
pydantic.NameEmail: cls.__faker__.free_email,
pydantic.Json: cls.__faker__.json,
pydantic.PaymentCardNumber: cls.__faker__.credit_card_number,
pydantic.AnyUrl: cls.__faker__.url,
pydantic.AnyHttpUrl: cls.__faker__.url,
pydantic.HttpUrl: cls.__faker__.url,
pydantic.SecretBytes: partial(create_random_bytes, cls.__random__),
pydantic.SecretStr: cls.__faker__.pystr,
pydantic.IPvAnyAddress: cls.__faker__.ipv4,
pydantic.IPvAnyInterface: cls.__faker__.ipv4,
pydantic.IPvAnyNetwork: lambda: cls.__faker__.ipv4(network=True),
pydantic.PastDate: cls.__faker__.past_date,
pydantic.FutureDate: cls.__faker__.future_date,
ByteSize: cls.__faker__.pyint,
PositiveInt: cls.__faker__.pyint,
NegativeFloat: lambda: cls.__random__.uniform(-100, -1),
NegativeInt: lambda: cls.__faker__.pyint() * -1,
PositiveFloat: cls.__faker__.pyint,
NonPositiveFloat: lambda: cls.__random__.uniform(-100, 0),
NonNegativeInt: cls.__faker__.pyint,
StrictInt: cls.__faker__.pyint,
StrictBool: cls.__faker__.pybool,
StrictBytes: lambda: create_random_bytes(cls.__random__),
StrictFloat: cls.__faker__.pyfloat,
StrictStr: cls.__faker__.pystr,
EmailStr: cls.__faker__.free_email,
NameEmail: cls.__faker__.free_email,
Json: cls.__faker__.json,
PaymentCardNumber: cls.__faker__.credit_card_number,
AnyUrl: cls.__faker__.url,
AnyHttpUrl: cls.__faker__.url,
HttpUrl: cls.__faker__.url,
SecretBytes: lambda: create_random_bytes(cls.__random__),
SecretStr: cls.__faker__.pystr,
IPvAnyAddress: cls.__faker__.ipv4,
IPvAnyInterface: cls.__faker__.ipv4,
IPvAnyNetwork: lambda: cls.__faker__.ipv4(network=True),
PastDate: cls.__faker__.past_date,
FutureDate: cls.__faker__.future_date,
}

# v1 only values
mapping.update(
{
PyObject: lambda: "decimal.Decimal",
AmqpDsn: lambda: "amqps://example.com",
KafkaDsn: lambda: "kafka://localhost:9092",
PostgresDsn: lambda: "postgresql://user@localhost",
RedisDsn: lambda: "redis://localhost:6379/0",
FilePath: lambda: Path(realpath(__file__)),
DirectoryPath: lambda: Path(realpath(__file__)).parent,
UUID1: uuid1,
UUID3: lambda: uuid3(NAMESPACE_DNS, cls.__faker__.pystr()),
UUID4: cls.__faker__.uuid4,
UUID5: lambda: uuid5(NAMESPACE_DNS, cls.__faker__.pystr()),
pydantic_v1.PyObject: lambda: "decimal.Decimal",
pydantic_v1.AmqpDsn: lambda: "amqps://example.com",
pydantic_v1.KafkaDsn: lambda: "kafka://localhost:9092",
pydantic_v1.PostgresDsn: lambda: "postgresql://user@localhost",
pydantic_v1.RedisDsn: lambda: "redis://localhost:6379/0",
pydantic_v1.FilePath: lambda: Path(realpath(__file__)),
pydantic_v1.DirectoryPath: lambda: Path(realpath(__file__)).parent,
pydantic_v1.UUID1: uuid1,
pydantic_v1.UUID3: lambda: uuid3(NAMESPACE_DNS, cls.__faker__.pystr()),
pydantic_v1.UUID4: cls.__faker__.uuid4,
pydantic_v1.UUID5: lambda: uuid5(NAMESPACE_DNS, cls.__faker__.pystr()),
Color: cls.__faker__.hex_color, # pyright: ignore[reportGeneralTypeIssues]
},
)
Expand Down

0 comments on commit a7dda85

Please sign in to comment.