From 604fb7d507b5a4a8bd194b0fb546ad90f894ed60 Mon Sep 17 00:00:00 2001 From: Tarrailt <3165388245@qq.com> Date: Fri, 31 Jan 2025 23:05:47 +0800 Subject: [PATCH] :sparkles: model_validator & field_validator for pydantic compat --- nonebot/compat.py | 49 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/nonebot/compat.py b/nonebot/compat.py index a13a88692e29..b5e20fc36dbd 100644 --- a/nonebot/compat.py +++ b/nonebot/compat.py @@ -45,6 +45,7 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ... __all__ = ( "DEFAULT_CONFIG", + "PYDANTIC_V2", "ConfigDict", "FieldInfo", "ModelField", @@ -54,9 +55,11 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ... "TypeAdapter", "custom_validation", "extract_field_info", + "field_validator", "model_config", "model_dump", "model_fields", + "model_validator", "type_validate_json", "type_validate_python", ) @@ -70,6 +73,8 @@ def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: ... if PYDANTIC_V2: # pragma: pydantic-v2 from pydantic import GetCoreSchemaHandler from pydantic import TypeAdapter as TypeAdapter + from pydantic import field_validator as field_validator + from pydantic import model_validator as model_validator from pydantic._internal._repr import display_as_type from pydantic.fields import FieldInfo as BaseFieldInfo from pydantic_core import CoreSchema, core_schema @@ -254,7 +259,7 @@ def custom_validation(class_: type["CVC"]) -> type["CVC"]: else: # pragma: pydantic-v1 from pydantic import BaseConfig as PydanticConfig - from pydantic import Extra, parse_obj_as, parse_raw_as + from pydantic import Extra, parse_obj_as, parse_raw_as, root_validator, validator from pydantic.fields import FieldInfo as BaseFieldInfo from pydantic.fields import ModelField as BaseModelField from pydantic.schema import get_annotation_from_field_info @@ -367,6 +372,36 @@ def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]: kwargs.update(field_info.extra) return kwargs + @overload + def field_validator( + field: str, + /, + *fields: str, + mode: Literal["before"], + check_fields: Optional[bool] = None, + ): ... + + @overload + def field_validator( + field: str, + /, + *fields: str, + mode: Literal["after"], + check_fields: Optional[bool] = None, + ): ... + + def field_validator( + field: str, + /, + *fields: str, + mode: Literal["before", "after"], + check_fields: Optional[bool] = None, + ): + if mode == "before": + return validator(field, *fields, pre=True, check_fields=check_fields or True, allow_reuse=True) + else: + return validator(field, *fields, check_fields=check_fields or True, allow_reuse=True) + def model_fields(model: type[BaseModel]) -> list[ModelField]: """Get field list of a model.""" @@ -404,6 +439,18 @@ def model_dump( exclude_none=exclude_none, ) + @overload + def model_validator(*, mode: Literal["before"]): ... + + @overload + def model_validator(*, mode: Literal["after"]): ... + + def model_validator(*, mode: Literal["before", "after"]): + if mode == "before": + return root_validator(pre=True, allow_reuse=True) + else: + return root_validator(skip_on_failure=True, allow_reuse=True) + def type_validate_python(type_: type[T], data: Any) -> T: """Validate data with given type.""" return parse_obj_as(type_, data)