Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions benches/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ fn list_error_json(bench: &mut Bencher) {
Ok(_) => panic!("unexpectedly valid"),
Err(e) => {
let v = e.value(py);
// println!("error: {}", v.to_string());
assert_eq!(v.getattr("title").unwrap().to_string(), "list[int]");
let error_count: i64 = v.call_method0("error_count").unwrap().extract().unwrap();
assert_eq!(error_count, 100);
Expand Down Expand Up @@ -184,7 +183,6 @@ fn list_error_python_input(py: Python<'_>) -> (SchemaValidator, PyObject) {
Ok(_) => panic!("unexpectedly valid"),
Err(e) => {
let v = e.value(py);
// println!("error: {}", v.to_string());
assert_eq!(v.getattr("title").unwrap().to_string(), "list[int]");
let error_count: i64 = v.call_method0("error_count").unwrap().extract().unwrap();
assert_eq!(error_count, 100);
Expand Down Expand Up @@ -357,7 +355,6 @@ fn dict_value_error(bench: &mut Bencher) {
Ok(_) => panic!("unexpectedly valid"),
Err(e) => {
let v = e.value(py);
// println!("error: {}", v.to_string());
assert_eq!(v.getattr("title").unwrap().to_string(), "dict[str,constrained-int]");
let error_count: i64 = v.call_method0("error_count").unwrap().extract().unwrap();
assert_eq!(error_count, 100);
Expand Down Expand Up @@ -484,7 +481,6 @@ fn typed_dict_deep_error(bench: &mut Bencher) {
Ok(_) => panic!("unexpectedly valid"),
Err(e) => {
let v = e.value(py);
// println!("error: {}", v.to_string());
assert_eq!(v.getattr("title").unwrap().to_string(), "typed-dict");
let error_count: i64 = v.call_method0("error_count").unwrap().extract().unwrap();
assert_eq!(error_count, 1);
Expand Down
60 changes: 60 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from collections.abc import Hashable, Mapping
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from fractions import Fraction
from re import Pattern
from typing import TYPE_CHECKING, Any, Callable, Literal, Union

Expand Down Expand Up @@ -811,6 +812,61 @@ def decimal_schema(
serialization=serialization,
)

class FractionSchema(TypedDict, total=False):
type: Required[Literal['decimal']]
le: Fraction
ge: Fraction
lt: Fraction
gt: Fraction
strict: bool
ref: str
metadata: dict[str, Any]
serialization: SerSchema

def fraction_schema(
*,
le: Fraction | None = None,
ge: Fraction | None = None,
lt: Fraction | None = None,
gt: Fraction | None = None,
strict: bool | None = None,
ref: str | None = None,
metadata: dict[str, Any] | None = None,
serialization: SerSchema | None = None,
) -> FractionSchema:
"""
Returns a schema that matches a fraction value, e.g.:

```py
from fractions import Fraction
from pydantic_core import SchemaValidator, core_schema

schema = core_schema.fraction_schema(le=0.8, ge=0.2)
v = SchemaValidator(schema)
assert v.validate_python(1, 2) == Fraction(1, 2)
```

Args:
le: The value must be less than or equal to this number
ge: The value must be greater than or equal to this number
lt: The value must be strictly less than this number
gt: The value must be strictly greater than this number
strict: Whether the value should be a float or a value that can be converted to a float
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
serialization: Custom serialization schema
"""
return _dict_not_none(
type='fraction',
gt=gt,
ge=ge,
lt=lt,
le=le,
strict=strict,
ref=ref,
metadata=metadata,
serialization=serialization,
)

class ComplexSchema(TypedDict, total=False):
type: Required[Literal['complex']]
Expand Down Expand Up @@ -4111,6 +4167,7 @@ def definition_reference_schema(
IntSchema,
FloatSchema,
DecimalSchema,
FractionSchema,
StringSchema,
BytesSchema,
DateSchema,
Expand Down Expand Up @@ -4170,6 +4227,7 @@ def definition_reference_schema(
'int',
'float',
'decimal',
'fraction',
'str',
'bytes',
'date',
Expand Down Expand Up @@ -4321,6 +4379,8 @@ def definition_reference_schema(
'uuid_version',
'decimal_type',
'decimal_parsing',
'fraction_type',
'fraction_parsing',
'decimal_max_digits',
'decimal_max_places',
'decimal_whole_digits',
Expand Down
5 changes: 5 additions & 0 deletions src/errors/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,9 @@ error_types! {
DecimalWholeDigits {
whole_digits: {ctx_type: u64, ctx_fn: field_from_context},
},
// Fraction errors
FractionType {},
FractionParsing {},
// Complex errors
ComplexType {},
ComplexStrParsing {},
Expand Down Expand Up @@ -584,6 +587,8 @@ impl ErrorType {
Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digit{expected_plural} in total",
Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}",
Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point",
Self::FractionParsing {..} => "Input should be a valid fraction",
Self::FractionType {..} => "Fraction input should be an integer, float, string or Fraction object",
Self::ComplexType {..} => "Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex",
Self::ComplexStrParsing {..} => "Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex",
}
Expand Down
2 changes: 2 additions & 0 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ pub trait Input<'py>: fmt::Debug {

fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>>;

fn validate_fraction(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>>;

type Dict<'a>: ValidatedDict<'py>
where
Self: 'a;
Expand Down
17 changes: 17 additions & 0 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::input::return_enums::EitherComplex;
use crate::lookup_key::{LookupKey, LookupPath};
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::create_decimal;
use crate::validators::fraction::create_fraction;
use crate::validators::{TemporalUnitMode, ValBytesMode};

use super::datetime::{
Expand Down Expand Up @@ -217,6 +218,18 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
}
}

fn validate_fraction(&self, _strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
match self {
JsonValue::Float(f) => {
create_fraction(&PyString::new(py, &f.to_string()), self).map(ValidationMatch::strict)
}
JsonValue::Str(..) | JsonValue::Int(..) | JsonValue::BigInt(..) => {
create_fraction(&self.into_pyobject(py)?, self).map(ValidationMatch::strict)
}
_ => Err(ValError::new(ErrorTypeDefaults::FractionType, self)),
}
}

type Dict<'a>
= &'a JsonObject<'data>
where
Expand Down Expand Up @@ -472,6 +485,10 @@ impl<'py> Input<'py> for str {
create_decimal(self.into_pyobject(py)?.as_any(), self).map(ValidationMatch::lax)
}

fn validate_fraction(&self, _strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
create_fraction(self.into_pyobject(py)?.as_any(), self).map(ValidationMatch::lax)
}

type Dict<'a> = Never;

#[cfg_attr(has_coverage_attribute, coverage(off))]
Expand Down
60 changes: 43 additions & 17 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::str::from_utf8;
use pyo3::intern;
use pyo3::prelude::*;

use pyo3::sync::PyOnceLock;
use pyo3::types::PyType;
use pyo3::types::{
PyBool, PyByteArray, PyBytes, PyComplex, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator,
Expand All @@ -18,6 +17,7 @@ use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError,
use crate::tools::{extract_i64, safe_repr};
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::{create_decimal, get_decimal_type};
use crate::validators::fraction::{create_fraction, get_fraction_type};
use crate::validators::Exactness;
use crate::validators::TemporalUnitMode;
use crate::validators::ValBytesMode;
Expand Down Expand Up @@ -48,20 +48,6 @@ use super::{
Input,
};

static FRACTION_TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();

pub fn get_fraction_type(py: Python<'_>) -> &Bound<'_, PyType> {
FRACTION_TYPE
.get_or_init(py, || {
py.import("fractions")
.and_then(|fractions_module| fractions_module.getattr("Fraction"))
.unwrap()
.extract()
.unwrap()
})
.bind(py)
}

pub(crate) fn downcast_python_input<'py, T: PyTypeCheck>(input: &(impl Input<'py> + ?Sized)) -> Option<&Bound<'py, T>> {
input.as_python().and_then(|any| any.downcast::<T>().ok())
}
Expand Down Expand Up @@ -290,8 +276,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
float_as_int(self, self.extract::<f64>()?)
} else if let Ok(decimal) = self.validate_decimal(true, self.py()) {
decimal_as_int(self, &decimal.into_inner())
} else if self.is_instance(get_fraction_type(self.py()))? {
fraction_as_int(self)
} else if let Ok(fraction) = self.validate_fraction(true, self.py()) {
fraction_as_int(self, &fraction.into_inner())
} else if let Ok(float) = self.extract::<f64>() {
float_as_int(self, float)
} else if let Some(enum_val) = maybe_as_enum(self) {
Expand Down Expand Up @@ -349,6 +335,46 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
Err(ValError::new(ErrorTypeDefaults::FloatType, self))
}

fn validate_fraction(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
let fraction_type = get_fraction_type(py);

// Fast path for existing fraction objects
if self.is_exact_instance(fraction_type) {
return Ok(ValidationMatch::exact(self.to_owned().clone()));
}

// Check for fraction subclasses
if self.is_instance(fraction_type)? {
return Ok(ValidationMatch::lax(self.to_owned().clone()));
}

if !strict {
if self.is_instance_of::<PyString>() || (self.is_instance_of::<PyInt>() && !self.is_instance_of::<PyBool>())
{
// Checking isinstance for str / int / bool is fast compared to fraction / float
return create_fraction(self, self).map(ValidationMatch::lax);
}

if self.is_instance_of::<PyFloat>() {
return create_fraction(self.str()?.as_any(), self).map(ValidationMatch::lax);
}
}

let error_type = if strict {
ErrorType::IsInstanceOf {
class: fraction_type
.qualname()
.and_then(|name| name.extract())
.unwrap_or_else(|_| "Fraction".to_owned()),
context: None,
}
} else {
ErrorTypeDefaults::FractionType
};

Err(ValError::new(error_type, self))
}

fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
let decimal_type = get_decimal_type(py);

Expand Down
8 changes: 8 additions & 0 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::lookup_key::{LookupKey, LookupPath};
use crate::tools::safe_repr;
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::create_decimal;
use crate::validators::fraction::create_fraction;
use crate::validators::{TemporalUnitMode, ValBytesMode};

use super::datetime::{
Expand Down Expand Up @@ -154,6 +155,13 @@ impl<'py> Input<'py> for StringMapping<'py> {
}
}

fn validate_fraction(&self, _strict: bool, _py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
match self {
Self::String(s) => create_fraction(s, self).map(ValidationMatch::strict),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)),
}
}

type Dict<'a>
= StringMappingDict<'py>
where
Expand Down
29 changes: 12 additions & 17 deletions src/input/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,22 +228,17 @@ pub fn decimal_as_int<'py>(
Ok(EitherInt::Py(numerator))
}

pub fn fraction_as_int<'py>(input: &Bound<'py, PyAny>) -> ValResult<EitherInt<'py>> {
#[cfg(Py_3_12)]
let is_integer = input.call_method0("is_integer")?.extract::<bool>()?;
#[cfg(not(Py_3_12))]
let is_integer = input.getattr("denominator")?.extract::<i64>().is_ok_and(|d| d == 1);

if is_integer {
#[cfg(Py_3_11)]
let as_int = input.call_method0("__int__");
#[cfg(not(Py_3_11))]
let as_int = input.call_method0("__trunc__");
match as_int {
Ok(i) => Ok(EitherInt::Py(i.as_any().to_owned())),
Err(_) => Err(ValError::new(ErrorTypeDefaults::IntType, input)),
}
} else {
Err(ValError::new(ErrorTypeDefaults::IntFromFloat, input))
pub fn fraction_as_int<'py>(
input: &(impl Input<'py> + ?Sized),
fraction: &Bound<'py, PyAny>,
) -> ValResult<EitherInt<'py>> {
let py = fraction.py();

let (numerator, denominator) = fraction
.call_method0(intern!(py, "as_integer_ratio"))?
.extract::<(Bound<'_, PyAny>, Bound<'_, PyAny>)>()?;
if denominator.extract::<i64>().map_or(true, |d| d != 1) {
return Err(ValError::new(ErrorTypeDefaults::IntFromFloat, input));
}
Ok(EitherInt::Py(numerator))
}
3 changes: 3 additions & 0 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ pub(crate) fn infer_to_python_known<'py>(
v.into_py_any(py)?
}
ObType::Decimal => value.to_string().into_py_any(py)?,
ObType::Fraction => value.to_string().into_py_any(py)?,
ObType::StrSubclass => PyString::new(py, value.downcast::<PyString>()?.to_str()?).into(),
ObType::Bytes => state
.config
Expand Down Expand Up @@ -368,6 +369,7 @@ pub(crate) fn infer_serialize_known<'py, S: Serializer>(
type_serializers::float::serialize_f64(v, serializer, state.config.inf_nan_mode)
}
ObType::Decimal => value.to_string().serialize(serializer),
ObType::Fraction => value.to_string().serialize(serializer),
ObType::Str | ObType::StrSubclass => {
let py_str = value.downcast::<PyString>().map_err(py_err_se_err)?;
serialize_to_json(serializer)
Expand Down Expand Up @@ -517,6 +519,7 @@ pub(crate) fn infer_json_key_known<'a, 'py>(
}
}
ObType::Decimal => Ok(Cow::Owned(key.to_string())),
ObType::Fraction => Ok(Cow::Owned(key.to_string())),
ObType::Bool => super::type_serializers::simple::bool_json_key(key),
ObType::Str | ObType::StrSubclass => key.downcast::<PyString>()?.to_cow(),
ObType::Bytes => state
Expand Down
Loading
Loading