diff --git a/benches/main.rs b/benches/main.rs index 4f5ba1496..39f8fba16 100644 --- a/benches/main.rs +++ b/benches/main.rs @@ -151,7 +151,6 @@ fn list_error_json(bench: &mut Bencher) { Ok(_) => panic!("unexpectedly valid"), Err(e) => { let v = e.value(py); - // println!("error: {}", v.to_string()); assert_eq!(v.getattr("title").unwrap().to_string(), "list[int]"); let error_count: i64 = v.call_method0("error_count").unwrap().extract().unwrap(); assert_eq!(error_count, 100); @@ -184,7 +183,6 @@ fn list_error_python_input(py: Python<'_>) -> (SchemaValidator, PyObject) { Ok(_) => panic!("unexpectedly valid"), Err(e) => { let v = e.value(py); - // println!("error: {}", v.to_string()); assert_eq!(v.getattr("title").unwrap().to_string(), "list[int]"); let error_count: i64 = v.call_method0("error_count").unwrap().extract().unwrap(); assert_eq!(error_count, 100); @@ -357,7 +355,6 @@ fn dict_value_error(bench: &mut Bencher) { Ok(_) => panic!("unexpectedly valid"), Err(e) => { let v = e.value(py); - // println!("error: {}", v.to_string()); assert_eq!(v.getattr("title").unwrap().to_string(), "dict[str,constrained-int]"); let error_count: i64 = v.call_method0("error_count").unwrap().extract().unwrap(); assert_eq!(error_count, 100); @@ -484,7 +481,6 @@ fn typed_dict_deep_error(bench: &mut Bencher) { Ok(_) => panic!("unexpectedly valid"), Err(e) => { let v = e.value(py); - // println!("error: {}", v.to_string()); assert_eq!(v.getattr("title").unwrap().to_string(), "typed-dict"); let error_count: i64 = v.call_method0("error_count").unwrap().extract().unwrap(); assert_eq!(error_count, 1); diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index c8a3b6da6..9dcc1e9ba 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -10,6 +10,7 @@ from collections.abc import Hashable, Mapping from datetime import date, datetime, time, timedelta from decimal import Decimal +from fractions import Fraction from re import Pattern from typing import TYPE_CHECKING, Any, Callable, Literal, Union @@ -811,6 +812,61 @@ def decimal_schema( serialization=serialization, ) +class FractionSchema(TypedDict, total=False): + type: Required[Literal['decimal']] + le: Fraction + ge: Fraction + lt: Fraction + gt: Fraction + strict: bool + ref: str + metadata: dict[str, Any] + serialization: SerSchema + +def fraction_schema( + *, + le: Fraction | None = None, + ge: Fraction | None = None, + lt: Fraction | None = None, + gt: Fraction | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: dict[str, Any] | None = None, + serialization: SerSchema | None = None, +) -> FractionSchema: + """ + Returns a schema that matches a fraction value, e.g.: + + ```py + from fractions import Fraction + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.fraction_schema(le=0.8, ge=0.2) + v = SchemaValidator(schema) + assert v.validate_python(1, 2) == Fraction(1, 2) + ``` + + Args: + le: The value must be less than or equal to this number + ge: The value must be greater than or equal to this number + lt: The value must be strictly less than this number + gt: The value must be strictly greater than this number + strict: Whether the value should be a float or a value that can be converted to a float + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='fraction', + gt=gt, + ge=ge, + lt=lt, + le=le, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) class ComplexSchema(TypedDict, total=False): type: Required[Literal['complex']] @@ -4111,6 +4167,7 @@ def definition_reference_schema( IntSchema, FloatSchema, DecimalSchema, + FractionSchema, StringSchema, BytesSchema, DateSchema, @@ -4170,6 +4227,7 @@ def definition_reference_schema( 'int', 'float', 'decimal', + 'fraction', 'str', 'bytes', 'date', @@ -4321,6 +4379,8 @@ def definition_reference_schema( 'uuid_version', 'decimal_type', 'decimal_parsing', + 'fraction_type', + 'fraction_parsing', 'decimal_max_digits', 'decimal_max_places', 'decimal_whole_digits', diff --git a/src/errors/types.rs b/src/errors/types.rs index 1b821e23e..3da8e9364 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -433,6 +433,9 @@ error_types! { DecimalWholeDigits { whole_digits: {ctx_type: u64, ctx_fn: field_from_context}, }, + // Fraction errors + FractionType {}, + FractionParsing {}, // Complex errors ComplexType {}, ComplexStrParsing {}, @@ -584,6 +587,8 @@ impl ErrorType { Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digit{expected_plural} in total", Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}", Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point", + Self::FractionParsing {..} => "Input should be a valid fraction", + Self::FractionType {..} => "Fraction input should be an integer, float, string or Fraction object", Self::ComplexType {..} => "Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex", Self::ComplexStrParsing {..} => "Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex", } diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 17c9546db..d6309be20 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -115,6 +115,8 @@ pub trait Input<'py>: fmt::Debug { fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValMatch>; + fn validate_fraction(&self, strict: bool, py: Python<'py>) -> ValMatch>; + type Dict<'a>: ValidatedDict<'py> where Self: 'a; diff --git a/src/input/input_json.rs b/src/input/input_json.rs index 9d7763c6f..698da7776 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -13,6 +13,7 @@ use crate::input::return_enums::EitherComplex; use crate::lookup_key::{LookupKey, LookupPath}; use crate::validators::complex::string_to_complex; use crate::validators::decimal::create_decimal; +use crate::validators::fraction::create_fraction; use crate::validators::{TemporalUnitMode, ValBytesMode}; use super::datetime::{ @@ -217,6 +218,18 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> { } } + fn validate_fraction(&self, _strict: bool, py: Python<'py>) -> ValMatch> { + match self { + JsonValue::Float(f) => { + create_fraction(&PyString::new(py, &f.to_string()), self).map(ValidationMatch::strict) + } + JsonValue::Str(..) | JsonValue::Int(..) | JsonValue::BigInt(..) => { + create_fraction(&self.into_pyobject(py)?, self).map(ValidationMatch::strict) + } + _ => Err(ValError::new(ErrorTypeDefaults::FractionType, self)), + } + } + type Dict<'a> = &'a JsonObject<'data> where @@ -472,6 +485,10 @@ impl<'py> Input<'py> for str { create_decimal(self.into_pyobject(py)?.as_any(), self).map(ValidationMatch::lax) } + fn validate_fraction(&self, _strict: bool, py: Python<'py>) -> ValMatch> { + create_fraction(self.into_pyobject(py)?.as_any(), self).map(ValidationMatch::lax) + } + type Dict<'a> = Never; #[cfg_attr(has_coverage_attribute, coverage(off))] diff --git a/src/input/input_python.rs b/src/input/input_python.rs index d3a26bfc5..a55001acd 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -3,7 +3,6 @@ use std::str::from_utf8; use pyo3::intern; use pyo3::prelude::*; -use pyo3::sync::PyOnceLock; use pyo3::types::PyType; use pyo3::types::{ PyBool, PyByteArray, PyBytes, PyComplex, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, @@ -18,6 +17,7 @@ use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, use crate::tools::{extract_i64, safe_repr}; use crate::validators::complex::string_to_complex; use crate::validators::decimal::{create_decimal, get_decimal_type}; +use crate::validators::fraction::{create_fraction, get_fraction_type}; use crate::validators::Exactness; use crate::validators::TemporalUnitMode; use crate::validators::ValBytesMode; @@ -48,20 +48,6 @@ use super::{ Input, }; -static FRACTION_TYPE: PyOnceLock> = PyOnceLock::new(); - -pub fn get_fraction_type(py: Python<'_>) -> &Bound<'_, PyType> { - FRACTION_TYPE - .get_or_init(py, || { - py.import("fractions") - .and_then(|fractions_module| fractions_module.getattr("Fraction")) - .unwrap() - .extract() - .unwrap() - }) - .bind(py) -} - pub(crate) fn downcast_python_input<'py, T: PyTypeCheck>(input: &(impl Input<'py> + ?Sized)) -> Option<&Bound<'py, T>> { input.as_python().and_then(|any| any.downcast::().ok()) } @@ -290,8 +276,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> { float_as_int(self, self.extract::()?) } else if let Ok(decimal) = self.validate_decimal(true, self.py()) { decimal_as_int(self, &decimal.into_inner()) - } else if self.is_instance(get_fraction_type(self.py()))? { - fraction_as_int(self) + } else if let Ok(fraction) = self.validate_fraction(true, self.py()) { + fraction_as_int(self, &fraction.into_inner()) } else if let Ok(float) = self.extract::() { float_as_int(self, float) } else if let Some(enum_val) = maybe_as_enum(self) { @@ -349,6 +335,46 @@ impl<'py> Input<'py> for Bound<'py, PyAny> { Err(ValError::new(ErrorTypeDefaults::FloatType, self)) } + fn validate_fraction(&self, strict: bool, py: Python<'py>) -> ValMatch> { + let fraction_type = get_fraction_type(py); + + // Fast path for existing fraction objects + if self.is_exact_instance(fraction_type) { + return Ok(ValidationMatch::exact(self.to_owned().clone())); + } + + // Check for fraction subclasses + if self.is_instance(fraction_type)? { + return Ok(ValidationMatch::lax(self.to_owned().clone())); + } + + if !strict { + if self.is_instance_of::() || (self.is_instance_of::() && !self.is_instance_of::()) + { + // Checking isinstance for str / int / bool is fast compared to fraction / float + return create_fraction(self, self).map(ValidationMatch::lax); + } + + if self.is_instance_of::() { + return create_fraction(self.str()?.as_any(), self).map(ValidationMatch::lax); + } + } + + let error_type = if strict { + ErrorType::IsInstanceOf { + class: fraction_type + .qualname() + .and_then(|name| name.extract()) + .unwrap_or_else(|_| "Fraction".to_owned()), + context: None, + } + } else { + ErrorTypeDefaults::FractionType + }; + + Err(ValError::new(error_type, self)) + } + fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValMatch> { let decimal_type = get_decimal_type(py); diff --git a/src/input/input_string.rs b/src/input/input_string.rs index a635188a8..5e410d231 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -9,6 +9,7 @@ use crate::lookup_key::{LookupKey, LookupPath}; use crate::tools::safe_repr; use crate::validators::complex::string_to_complex; use crate::validators::decimal::create_decimal; +use crate::validators::fraction::create_fraction; use crate::validators::{TemporalUnitMode, ValBytesMode}; use super::datetime::{ @@ -154,6 +155,13 @@ impl<'py> Input<'py> for StringMapping<'py> { } } + fn validate_fraction(&self, _strict: bool, _py: Python<'py>) -> ValMatch> { + match self { + Self::String(s) => create_fraction(s, self).map(ValidationMatch::strict), + Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)), + } + } + type Dict<'a> = StringMappingDict<'py> where diff --git a/src/input/shared.rs b/src/input/shared.rs index d623c1aeb..45120d639 100644 --- a/src/input/shared.rs +++ b/src/input/shared.rs @@ -228,22 +228,17 @@ pub fn decimal_as_int<'py>( Ok(EitherInt::Py(numerator)) } -pub fn fraction_as_int<'py>(input: &Bound<'py, PyAny>) -> ValResult> { - #[cfg(Py_3_12)] - let is_integer = input.call_method0("is_integer")?.extract::()?; - #[cfg(not(Py_3_12))] - let is_integer = input.getattr("denominator")?.extract::().is_ok_and(|d| d == 1); - - if is_integer { - #[cfg(Py_3_11)] - let as_int = input.call_method0("__int__"); - #[cfg(not(Py_3_11))] - let as_int = input.call_method0("__trunc__"); - match as_int { - Ok(i) => Ok(EitherInt::Py(i.as_any().to_owned())), - Err(_) => Err(ValError::new(ErrorTypeDefaults::IntType, input)), - } - } else { - Err(ValError::new(ErrorTypeDefaults::IntFromFloat, input)) +pub fn fraction_as_int<'py>( + input: &(impl Input<'py> + ?Sized), + fraction: &Bound<'py, PyAny>, +) -> ValResult> { + let py = fraction.py(); + + let (numerator, denominator) = fraction + .call_method0(intern!(py, "as_integer_ratio"))? + .extract::<(Bound<'_, PyAny>, Bound<'_, PyAny>)>()?; + if denominator.extract::().map_or(true, |d| d != 1) { + return Err(ValError::new(ErrorTypeDefaults::IntFromFloat, input)); } + Ok(EitherInt::Py(numerator)) } diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index 2760e0824..e06be95f7 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -110,6 +110,7 @@ pub(crate) fn infer_to_python_known<'py>( v.into_py_any(py)? } ObType::Decimal => value.to_string().into_py_any(py)?, + ObType::Fraction => value.to_string().into_py_any(py)?, ObType::StrSubclass => PyString::new(py, value.downcast::()?.to_str()?).into(), ObType::Bytes => state .config @@ -368,6 +369,7 @@ pub(crate) fn infer_serialize_known<'py, S: Serializer>( type_serializers::float::serialize_f64(v, serializer, state.config.inf_nan_mode) } ObType::Decimal => value.to_string().serialize(serializer), + ObType::Fraction => value.to_string().serialize(serializer), ObType::Str | ObType::StrSubclass => { let py_str = value.downcast::().map_err(py_err_se_err)?; serialize_to_json(serializer) @@ -517,6 +519,7 @@ pub(crate) fn infer_json_key_known<'a, 'py>( } } ObType::Decimal => Ok(Cow::Owned(key.to_string())), + ObType::Fraction => Ok(Cow::Owned(key.to_string())), ObType::Bool => super::type_serializers::simple::bool_json_key(key), ObType::Str | ObType::StrSubclass => key.downcast::()?.to_cow(), ObType::Bytes => state diff --git a/src/serializers/ob_type.rs b/src/serializers/ob_type.rs index 8d291ab3d..3a700e391 100644 --- a/src/serializers/ob_type.rs +++ b/src/serializers/ob_type.rs @@ -23,6 +23,7 @@ pub struct ObTypeLookup { dict: usize, // other numeric types decimal_object: Py, + fraction_object: Py, // other string types bytes: usize, bytearray: usize, @@ -71,6 +72,7 @@ impl ObTypeLookup { list: PyList::type_object_raw(py) as usize, dict: PyDict::type_object_raw(py) as usize, decimal_object: py.import("decimal").unwrap().getattr("Decimal").unwrap().unbind(), + fraction_object: py.import("fractions").unwrap().getattr("Fraction").unwrap().unbind(), string: PyString::type_object_raw(py) as usize, bytes: PyBytes::type_object_raw(py) as usize, bytearray: PyByteArray::type_object_raw(py) as usize, @@ -139,6 +141,7 @@ impl ObTypeLookup { ObType::List => self.list == ob_type, ObType::Dict => self.dict == ob_type, ObType::Decimal => self.decimal_object.as_ptr() as usize == ob_type, + ObType::Fraction => self.fraction_object.as_ptr() as usize == ob_type, ObType::StrSubclass => self.string == ob_type && op_value.is_none(), ObType::Tuple => self.tuple == ob_type, ObType::Set => self.set == ob_type, @@ -216,6 +219,8 @@ impl ObTypeLookup { ObType::Dict } else if ob_type == self.decimal_object.as_ptr() as usize { ObType::Decimal + } else if ob_type == self.fraction_object.as_ptr() as usize { + ObType::Fraction } else if ob_type == self.bytes { ObType::Bytes } else if ob_type == self.tuple { @@ -324,6 +329,8 @@ impl ObTypeLookup { ObType::MultiHostUrl } else if value.is_instance(self.decimal_object.bind(py)).unwrap_or(false) { ObType::Decimal + } else if value.is_instance(self.fraction_object.bind(py)).unwrap_or(false) { + ObType::Fraction } else if value.is_instance(self.uuid_object.bind(py)).unwrap_or(false) { ObType::Uuid } else if value.is_instance(self.enum_object.bind(py)).unwrap_or(false) { @@ -381,6 +388,7 @@ pub enum ObType { Float, FloatSubclass, Decimal, + Fraction, // string types Str, StrSubclass, diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index 98e91098c..72088a830 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -121,6 +121,7 @@ combined_serializer! { Bool: super::type_serializers::simple::BoolSerializer; Float: super::type_serializers::float::FloatSerializer; Decimal: super::type_serializers::decimal::DecimalSerializer; + Fraction: super::type_serializers::fraction::FractionSerializer; Str: super::type_serializers::string::StrSerializer; Bytes: super::type_serializers::bytes::BytesSerializer; Datetime: super::type_serializers::datetime_etc::DatetimeSerializer; @@ -326,6 +327,7 @@ impl PyGcTraverse for CombinedSerializer { CombinedSerializer::Bool(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Float(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Decimal(inner) => inner.py_gc_traverse(visit), + CombinedSerializer::Fraction(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Str(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Bytes(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Datetime(inner) => inner.py_gc_traverse(visit), diff --git a/src/serializers/type_serializers/fraction.rs b/src/serializers/type_serializers/fraction.rs new file mode 100644 index 000000000..a94b7d2f2 --- /dev/null +++ b/src/serializers/type_serializers/fraction.rs @@ -0,0 +1,86 @@ +use std::borrow::Cow; +use std::sync::Arc; + +use pyo3::prelude::*; +use pyo3::types::PyDict; + +use crate::build_tools::LazyLock; +use crate::definitions::DefinitionsBuilder; +use crate::serializers::infer::{infer_json_key_known, infer_serialize_known, infer_to_python_known}; +use crate::serializers::ob_type::{IsType, ObType}; + +use crate::serializers::SerializationState; + +use super::{ + infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, TypeSerializer, +}; + +#[derive(Debug)] +pub struct FractionSerializer {} + +static FRACTION_SERIALIZER: LazyLock> = + LazyLock::new(|| Arc::new(FractionSerializer {}.into())); + +impl BuildSerializer for FractionSerializer { + const EXPECTED_TYPE: &'static str = "fraction"; + + fn build( + _schema: &Bound<'_, PyDict>, + _config: Option<&Bound<'_, PyDict>>, + _definitions: &mut DefinitionsBuilder>, + ) -> PyResult> { + Ok(FRACTION_SERIALIZER.clone()) + } +} + +impl_py_gc_traverse!(FractionSerializer {}); + +impl TypeSerializer for FractionSerializer { + fn to_python<'py>( + &self, + value: &Bound<'py, PyAny>, + state: &mut SerializationState<'_, 'py>, + ) -> PyResult> { + let _py = value.py(); + match state.extra.ob_type_lookup.is_type(value, ObType::Fraction) { + IsType::Exact | IsType::Subclass => infer_to_python_known(ObType::Fraction, value, state), + IsType::False => { + state.warn_fallback_py(self.get_name(), value)?; + infer_to_python(value, state) + } + } + } + + fn json_key<'a, 'py>( + &self, + key: &'a Bound<'py, PyAny>, + state: &mut SerializationState<'_, 'py>, + ) -> PyResult> { + match state.extra.ob_type_lookup.is_type(key, ObType::Fraction) { + IsType::Exact | IsType::Subclass => infer_json_key_known(ObType::Fraction, key, state), + IsType::False => { + state.warn_fallback_py(self.get_name(), key)?; + infer_json_key(key, state) + } + } + } + + fn serde_serialize<'py, S: serde::ser::Serializer>( + &self, + value: &Bound<'py, PyAny>, + serializer: S, + state: &mut SerializationState<'_, 'py>, + ) -> Result { + match state.extra.ob_type_lookup.is_type(value, ObType::Decimal) { + IsType::Exact | IsType::Subclass => infer_serialize_known(ObType::Fraction, value, serializer, state), + IsType::False => { + state.warn_fallback_ser::(self.get_name(), value)?; + infer_serialize(value, serializer, state) + } + } + } + + fn get_name(&self) -> &str { + Self::EXPECTED_TYPE + } +} diff --git a/src/serializers/type_serializers/mod.rs b/src/serializers/type_serializers/mod.rs index 5fe990382..04315fa06 100644 --- a/src/serializers/type_serializers/mod.rs +++ b/src/serializers/type_serializers/mod.rs @@ -9,6 +9,7 @@ pub mod dict; pub mod enum_; pub mod float; pub mod format; +pub mod fraction; pub mod function; pub mod generator; pub mod json; diff --git a/src/validators/fraction.rs b/src/validators/fraction.rs new file mode 100644 index 000000000..58d2606e6 --- /dev/null +++ b/src/validators/fraction.rs @@ -0,0 +1,159 @@ +use std::sync::Arc; + +use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::intern; +use pyo3::sync::PyOnceLock; +use pyo3::types::{IntoPyDict, PyDict, PyString, PyType}; +use pyo3::{prelude::*, PyTypeInfo}; + +use crate::build_tools::is_strict; +use crate::errors::ErrorTypeDefaults; +use crate::errors::ValResult; +use crate::errors::{ErrorType, Number, ToErrorValue, ValError}; +use crate::input::Input; + +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; + +static FRACTION_TYPE: PyOnceLock> = PyOnceLock::new(); + +pub fn get_fraction_type(py: Python<'_>) -> &Bound<'_, PyType> { + FRACTION_TYPE + .get_or_init(py, || { + py.import("fractions") + .and_then(|fraction_module| fraction_module.getattr("Fraction")) + .unwrap() + .extract() + .unwrap() + }) + .bind(py) +} + +fn validate_as_fraction( + py: Python, + schema: &Bound<'_, PyDict>, + key: &Bound<'_, PyString>, +) -> PyResult>> { + match schema.get_item(key)? { + Some(value) => match value.validate_fraction(false, py) { + Ok(v) => Ok(Some(v.into_inner().unbind())), + Err(_) => Err(PyValueError::new_err(format!( + "'{key}' must be coercible to a Fraction instance", + ))), + }, + None => Ok(None), + } +} + +#[derive(Debug, Clone)] +pub struct FractionValidator { + strict: bool, + le: Option>, + lt: Option>, + ge: Option>, + gt: Option>, +} + +impl BuildValidator for FractionValidator { + const EXPECTED_TYPE: &'static str = "fraction"; + fn build( + schema: &Bound<'_, PyDict>, + config: Option<&Bound<'_, PyDict>>, + _definitions: &mut DefinitionsBuilder>, + ) -> PyResult> { + let py = schema.py(); + + Ok(CombinedValidator::Fraction(Self { + strict: is_strict(schema, config)?, + le: validate_as_fraction(py, schema, intern!(py, "le"))?, + lt: validate_as_fraction(py, schema, intern!(py, "lt"))?, + ge: validate_as_fraction(py, schema, intern!(py, "ge"))?, + gt: validate_as_fraction(py, schema, intern!(py, "gt"))?, + }) + .into()) + } +} + +impl_py_gc_traverse!(FractionValidator { le, lt, ge, gt }); + +impl Validator for FractionValidator { + fn validate<'py>( + &self, + py: Python<'py>, + input: &(impl Input<'py> + ?Sized), + state: &mut ValidationState<'_, 'py>, + ) -> ValResult> { + let fraction = input.validate_fraction(state.strict_or(self.strict), py)?.unpack(state); + + if let Some(le) = &self.le { + if !fraction.le(le)? { + return Err(ValError::new( + ErrorType::LessThanEqual { + le: Number::String(le.to_string()), + context: Some([("le", le)].into_py_dict(py)?.into()), + }, + input, + )); + } + } + if let Some(lt) = &self.lt { + if !fraction.lt(lt)? { + return Err(ValError::new( + ErrorType::LessThan { + lt: Number::String(lt.to_string()), + context: Some([("lt", lt)].into_py_dict(py)?.into()), + }, + input, + )); + } + } + if let Some(ge) = &self.ge { + if !fraction.ge(ge)? { + return Err(ValError::new( + ErrorType::GreaterThanEqual { + ge: Number::String(ge.to_string()), + context: Some([("ge", ge)].into_py_dict(py)?.into()), + }, + input, + )); + } + } + if let Some(gt) = &self.gt { + if !fraction.gt(gt)? { + return Err(ValError::new( + ErrorType::GreaterThan { + gt: Number::String(gt.to_string()), + context: Some([("gt", gt)].into_py_dict(py)?.into()), + }, + input, + )); + } + } + + Ok(fraction.into()) + } + + fn get_name(&self) -> &str { + Self::EXPECTED_TYPE + } +} + +pub(crate) fn create_fraction<'py>(arg: &Bound<'py, PyAny>, input: impl ToErrorValue) -> ValResult> { + let py = arg.py(); + get_fraction_type(py) + .call1((arg,)) + .map_err(|e| handle_fraction_new_error(input, e)) +} + +fn handle_fraction_new_error(input: impl ToErrorValue, error: PyErr) -> ValError { + Python::attach(|py| { + if error.matches(py, PyValueError::type_object(py)).unwrap_or(false) { + ValError::new(ErrorTypeDefaults::FractionParsing, input) + } else if error.matches(py, PyTypeError::type_object(py)).unwrap_or(false) { + ValError::new(ErrorTypeDefaults::FractionType, input) + } else { + // Let ZeroDivisionError and other exceptions bubble up as InternalErr + // which will be shown to the user with the original Python error message + ValError::InternalErr(error) + } + }) +} diff --git a/src/validators/mod.rs b/src/validators/mod.rs index adcf1ba55..f7a2baff5 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -38,6 +38,7 @@ mod definitions; mod dict; mod enum_; mod float; +pub(crate) mod fraction; mod frozenset; mod function; mod generator; @@ -592,6 +593,8 @@ fn build_validator_inner( float::FloatBuilder, // decimals decimal::DecimalValidator, + // fractions + fraction::FractionValidator, // tuples tuple::TupleValidator, // list/arrays @@ -767,6 +770,8 @@ pub enum CombinedValidator { ConstrainedFloat(float::ConstrainedFloatValidator), // decimals Decimal(decimal::DecimalValidator), + // fractions + Fraction(fraction::FractionValidator), // lists List(list::ListValidator), // sets - unique lists diff --git a/tests/serializers/test_fraction.py b/tests/serializers/test_fraction.py new file mode 100644 index 000000000..6d06c70e2 --- /dev/null +++ b/tests/serializers/test_fraction.py @@ -0,0 +1,47 @@ +from fractions import Fraction + +import pytest + +from pydantic_core import SchemaSerializer, core_schema + + +def test_fraction(): + v = SchemaSerializer(core_schema.fraction_schema()) + assert v.to_python(Fraction('3 / 4')) == Fraction(3, 4) + assert v.to_python(Fraction(3, 4)) == Fraction(3, 4) + + # check correct casting to int when denominator is 1 + assert v.to_python(Fraction(10, 10), mode='json') == '1' + assert v.to_python(Fraction(1, 10), mode='json') == '1/10' + + assert v.to_json(Fraction(3, 4)) == b'"3/4"' + + +def test_fraction_key(): + v = SchemaSerializer(core_schema.dict_schema(core_schema.fraction_schema(), core_schema.fraction_schema())) + assert v.to_python({Fraction(3, 4): Fraction(1, 10)}) == {Fraction(3, 4): Fraction(1, 10)} + assert v.to_python({Fraction(3, 4): Fraction(1, 10)}, mode='json') == {'3/4': '1/10'} + assert v.to_json({Fraction(3, 4): Fraction(1, 10)}) == b'{"3/4":"1/10"}' + + +@pytest.mark.parametrize( + 'value,expected', + [ + (Fraction(3, 4), '3/4'), + (Fraction(1, 10), '1/10'), + (Fraction(10, 1), '10'), + (Fraction(-5, 2), '-5/2'), + ], +) +def test_fraction_json(value, expected): + v = SchemaSerializer(core_schema.fraction_schema()) + assert v.to_python(value, mode='json') == expected + assert v.to_json(value).decode() == f'"{expected}"' + + +def test_any_fraction_key(): + v = SchemaSerializer(core_schema.dict_schema()) + input_value = {Fraction(3, 4): 1} + + assert v.to_python(input_value, mode='json') == {'3/4': 1} + assert v.to_json(input_value) == b'{"3/4":1}' diff --git a/tests/validators/test_fraction.py b/tests/validators/test_fraction.py new file mode 100644 index 000000000..dd28b4ccc --- /dev/null +++ b/tests/validators/test_fraction.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +import json +import re +from fractions import Fraction +from typing import Any + +import pytest +from dirty_equals import IsStr + +from pydantic_core import SchemaError, SchemaValidator, ValidationError +from pydantic_core import core_schema as cs + +from ..conftest import Err, PyAndJson, plain_repr + + +class FractionSubclass(Fraction): + pass + + +@pytest.mark.parametrize( + 'constraint', + ['le', 'lt', 'ge', 'gt'], +) +def test_constraints_schema_validation_error(constraint: str) -> None: + with pytest.raises(SchemaError, match=f"'{constraint}' must be coercible to a Fraction instance"): + SchemaValidator(cs.fraction_schema(**{constraint: 'bad_value'})) + + +def test_constraints_schema_validation() -> None: + val = SchemaValidator(cs.fraction_schema(gt='1')) + with pytest.raises(ValidationError): + val.validate_python('0') + + +@pytest.mark.parametrize( + 'input_value,expected', + [ + (0, Fraction(0)), + (1, Fraction(1)), + (42, Fraction(42)), + ('42', Fraction(42)), + ('42.123', Fraction('42.123')), + (42.0, Fraction(42)), + (42.5, Fraction('42.5')), + (1e10, Fraction('1E10')), + (Fraction('42.0'), Fraction(42)), + (Fraction('42.5'), Fraction('42.5')), + (Fraction('1e10'), Fraction('1E10')), + ( + Fraction('123456789123456789123456789.123456789123456789123456789'), + Fraction('123456789123456789123456789.123456789123456789123456789'), + ), + (FractionSubclass('42.0'), Fraction(42)), + (FractionSubclass('42.5'), Fraction('42.5')), + (FractionSubclass('1e10'), Fraction('1E10')), + ( + True, + Err( + 'Fraction input should be an integer, float, string or Fraction object [type=fraction_type, input_value=True, input_type=bool]' + ), + ), + ( + False, + Err( + 'Fraction input should be an integer, float, string or Fraction object [type=fraction_type, input_value=False, input_type=bool]' + ), + ), + ('wrong', Err('Input should be a valid fraction [type=fraction_parsing')), + ( + [1, 2], + Err( + 'Fraction input should be an integer, float, string or Fraction object [type=fraction_type, input_value=[1, 2], input_type=list]' + ), + ), + ], +) +def test_fraction(py_and_json: PyAndJson, input_value, expected): + v = py_and_json({'type': 'fraction'}) + # Fraction types are not JSON serializable + if v.validator_type == 'json' and isinstance(input_value, Fraction): + input_value = str(input_value) + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=re.escape(expected.message)): + v.validate_test(input_value) + else: + output = v.validate_test(input_value) + assert output == expected + assert isinstance(output, Fraction) + + +@pytest.mark.parametrize( + 'input_value,expected', + [ + (Fraction(0), Fraction(0)), + (Fraction(1), Fraction(1)), + (Fraction(42), Fraction(42)), + (Fraction('42.0'), Fraction('42.0')), + (Fraction('42.5'), Fraction('42.5')), + (42.0, Err('Input should be an instance of Fraction [type=is_instance_of, input_value=42.0, input_type=float]')), + ('42', Err("Input should be an instance of Fraction [type=is_instance_of, input_value='42', input_type=str]")), + (42, Err('Input should be an instance of Fraction [type=is_instance_of, input_value=42, input_type=int]')), + (True, Err('Input should be an instance of Fraction [type=is_instance_of, input_value=True, input_type=bool]')), + ], + ids=repr, +) +def test_fraction_strict_py(input_value, expected): + v = SchemaValidator(cs.fraction_schema(strict=True)) + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=re.escape(expected.message)): + v.validate_python(input_value) + else: + output = v.validate_python(input_value) + assert output == expected + assert isinstance(output, Fraction) + + +@pytest.mark.parametrize( + 'input_value,expected', + [ + (0, Fraction(0)), + (1, Fraction(1)), + (42, Fraction(42)), + ('42.0', Fraction('42.0')), + ('42.5', Fraction('42.5')), + (42.0, Fraction('42.0')), + ('42', Fraction('42')), + ( + True, + Err( + 'Fraction input should be an integer, float, string or Fraction object [type=fraction_type, input_value=True, input_type=bool]' + ), + ), + ], + ids=repr, +) +def test_fraction_strict_json(input_value, expected): + v = SchemaValidator(cs.fraction_schema(strict=True)) + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=re.escape(expected.message)): + v.validate_json(json.dumps(input_value)) + else: + output = v.validate_json(json.dumps(input_value)) + assert output == expected + assert isinstance(output, Fraction) + + +@pytest.mark.parametrize( + 'kwargs,input_value,expected', + [ + ({}, 0, Fraction(0)), + ({}, '123.456', Fraction('123.456')), + ({'ge': 0}, 0, Fraction(0)), + ( + {'ge': 0}, + -0.1, + Err( + 'Input should be greater than or equal to 0 ' + '[type=greater_than_equal, input_value=-0.1, input_type=float]' + ), + ), + ({'gt': 0}, 0.1, Fraction('0.1')), + ({'gt': 0}, 0, Err('Input should be greater than 0 [type=greater_than, input_value=0, input_type=int]')), + ({'le': 0}, 0, Fraction(0)), + ({'le': 0}, -1, Fraction(-1)), + ({'le': 0}, 0.1, Err('Input should be less than or equal to 0')), + ({'lt': 0}, 0, Err('Input should be less than 0')), + ({'lt': 0.123456}, 1, Err('Input should be less than 1929/15625')), + ({'lt': 0.123456}, '0.1', Fraction('0.1')), + ], +) +def test_fraction_kwargs(py_and_json: PyAndJson, kwargs: dict[str, Any], input_value, expected): + v = py_and_json({'type': 'fraction', **kwargs}) + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=re.escape(expected.message)): + v.validate_test(input_value) + else: + output = v.validate_test(input_value) + assert output == expected + assert isinstance(output, Fraction) + + +def test_union_fraction_py(): + v = SchemaValidator(cs.union_schema(choices=[cs.fraction_schema(strict=True), cs.fraction_schema(gt=0)])) + assert v.validate_python('14') == 14 + assert v.validate_python(Fraction(5)) == 5 + with pytest.raises(ValidationError) as exc_info: + v.validate_python('-5') + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'is_instance_of', + 'loc': ('fraction',), + 'msg': 'Input should be an instance of Fraction', + 'input': '-5', + 'ctx': {'class': 'Fraction'}, + }, + { + 'type': 'greater_than', + 'loc': ('fraction',), + 'msg': 'Input should be greater than 0', + 'input': '-5', + 'ctx': {'gt': Fraction(0)}, + }, + ] + + +def test_union_fraction_json(): + v = SchemaValidator(cs.union_schema(choices=[cs.fraction_schema(strict=True), cs.fraction_schema(gt=0)])) + assert v.validate_json(json.dumps('14')) == 14 + assert v.validate_json(json.dumps('5')) == 5 + + +def test_union_fraction_simple(py_and_json: PyAndJson): + v = py_and_json({'type': 'union', 'choices': [{'type': 'fraction'}, {'type': 'list'}]}) + assert v.validate_test('5') == 5 + with pytest.raises(ValidationError) as exc_info: + v.validate_test('xxx') + + assert exc_info.value.errors(include_url=False) == [ + {'type': 'fraction_parsing', 'loc': ('fraction',), 'msg': 'Input should be a valid fraction', 'input': 'xxx'}, + { + 'type': 'list_type', + 'loc': ('list[any]',), + 'msg': IsStr(regex='Input should be a valid (list|array)'), + 'input': 'xxx', + }, + ] + + +def test_fraction_repr(): + v = SchemaValidator(cs.fraction_schema()) + assert plain_repr(v).startswith( + 'SchemaValidator(title="fraction",validator=Fraction(FractionValidator{strict:false' + ) + v = SchemaValidator(cs.fraction_schema(strict=True)) + assert plain_repr(v).startswith( + 'SchemaValidator(title="fraction",validator=Fraction(FractionValidator{strict:true' + ) + + +@pytest.mark.parametrize('input_value,expected', [(Fraction('1.23'), Fraction('1.23')), (Fraction('1'), Fraction('1.0'))]) +def test_fraction_not_json(input_value, expected): + v = SchemaValidator(cs.fraction_schema()) + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=re.escape(expected.message)): + v.validate_python(input_value) + else: + output = v.validate_python(input_value) + assert output == expected + assert isinstance(output, Fraction) + + +def test_fraction_key(py_and_json: PyAndJson): + v = py_and_json({'type': 'dict', 'keys_schema': {'type': 'fraction'}, 'values_schema': {'type': 'int'}}) + assert v.validate_test({'1': 1, '2': 2}) == {Fraction('1'): 1, Fraction('2'): 2} + assert v.validate_test({'1.5': 1, '2.4': 2}) == {Fraction('1.5'): 1, Fraction('2.4'): 2} + if v.validator_type == 'python': + with pytest.raises(ValidationError, match='Input should be an instance of Fraction'): + v.validate_test({'1.5': 1, '2.5': 2}, strict=True) + else: + assert v.validate_test({'1.5': 1, '2.4': 2}, strict=True) == {Fraction('1.5'): 1, Fraction('2.4'): 2} + + +@pytest.mark.parametrize( + 'input_value,expected', + [ + ('NaN', Err("Input should be a valid fraction [type=fraction_parsing, input_value='NaN', input_type=str]")), + ('0.7', Fraction('0.7')), + ( + 'pika', + Err("Input should be a valid fraction [type=fraction_parsing, input_value='pika', input_type=str]"), + ), + ], +) +def test_non_finite_json_values(py_and_json: PyAndJson, input_value, expected): + v = py_and_json({'type': 'fraction'}) + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=re.escape(expected.message)): + v.validate_test(input_value) + else: + assert v.validate_test(input_value) == expected + + +@pytest.mark.parametrize( + 'input_value,expected', + [ + # lower e, minus + ('1.0e-12', Fraction('1e-12')), + ('1e-12', Fraction('1e-12')), + ('12e-1', Fraction('12e-1')), + # upper E, minus + ('1.0E-12', Fraction('1e-12')), + ('1E-12', Fraction('1e-12')), + ('12E-1', Fraction('12e-1')), + # lower E, plus + ('1.0e+12', Fraction(' 1e12')), + ('1e+12', Fraction(' 1e12')), + ('12e+1', Fraction(' 12e1')), + # upper E, plus + ('1.0E+12', Fraction(' 1e12')), + ('1E+12', Fraction(' 1e12')), + ('12E+1', Fraction(' 12e1')), + # lower E, unsigned + ('1.0e12', Fraction(' 1e12')), + ('1e12', Fraction(' 1e12')), + ('12e1', Fraction(' 12e1')), + # upper E, unsigned + ('1.0E12', Fraction(' 1e12')), + ('1E12', Fraction(' 1e12')), + ('12E1', Fraction(' 12e1')), + ], +) +def test_validate_scientific_notation_from_json(input_value, expected): + v = SchemaValidator(cs.fraction_schema()) + assert v.validate_json(input_value) == expected + + +def test_str_validation_w_strict() -> None: + s = SchemaValidator(cs.fraction_schema(strict=True)) + + with pytest.raises(ValidationError): + assert s.validate_python('1.23') + + +def test_str_validation_w_lax() -> None: + s = SchemaValidator(cs.fraction_schema(strict=False)) + + assert s.validate_python('1.23') == Fraction('1.23') + + +def test_union_with_str_prefers_str() -> None: + s = SchemaValidator(cs.union_schema([cs.fraction_schema(), cs.str_schema()])) + + assert s.validate_python('1.23') == '1.23' + assert s.validate_python(1.23) == Fraction('1.23')