From c2777ffa6f1fef555cca503d42fe43ba0cc6e7d3 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Sat, 25 Mar 2023 14:29:25 -0400 Subject: [PATCH] Minimal handling of union types for pydantic (#91) This adds minimal handling and safeguards for Union type when working with pydantic models. --- kor/adapters.py | 48 ++++++++++++++++---- tests/{test_adpaters.py => test_adapters.py} | 39 ++++++++++++++++ 2 files changed, 79 insertions(+), 8 deletions(-) rename tests/{test_adpaters.py => test_adapters.py} (75%) diff --git a/kor/adapters.py b/kor/adapters.py index 4c6b00b..95836b4 100644 --- a/kor/adapters.py +++ b/kor/adapters.py @@ -1,12 +1,29 @@ """Adapters to convert from validation frameworks to Kor internal representation.""" import enum -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, get_origin +from typing import ( + Any, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, + Union, + get_args, + get_origin, +) from pydantic import BaseModel from .nodes import ExtractionSchemaNode, Number, Object, Option, Selection, Text from .validators import PydanticValidator, Validator +# Not going to support dicts or lists since that requires recursive checks. +# May make sense to either drop the internal representation, or properly extend it +# to handle Lists, Unions etc. +# Not worth the effort, until it's clear that folks are using this functionality. +PRIMITIVE_TYPES = {str, float, int, type(None)} + def _translate_pydantic_to_kor( model_class: Type[BaseModel], @@ -43,17 +60,32 @@ def _translate_pydantic_to_kor( type_ = field.type_ field_many = get_origin(field.outer_type_) is list attribute: Union[ExtractionSchemaNode, Selection, "Object"] - if issubclass(type_, BaseModel): - attribute = _translate_pydantic_to_kor( - type_, - description=field_description, + # Precedence matters here since bool is a subclass of int + if get_origin(type_) is Union: + args = get_args(type_) + + if not all(arg in PRIMITIVE_TYPES for arg in args): + raise NotImplementedError( + "Union of non-primitive types not supported. Issue with" + f"field: `{field_name}`. Has type: `{type_}`" + ) + + attribute = Text( + id=field_name, examples=field_examples, + description=field_description, many=field_many, - name=field_name, ) else: - # Precedence matters here since bool is a subclass of int - if issubclass(type_, bool): + if issubclass(type_, BaseModel): + attribute = _translate_pydantic_to_kor( + type_, + description=field_description, + examples=field_examples, + many=field_many, + name=field_name, + ) + elif issubclass(type_, bool): attribute = Text( id=field_name, examples=field_examples, diff --git a/tests/test_adpaters.py b/tests/test_adapters.py similarity index 75% rename from tests/test_adpaters.py rename to tests/test_adapters.py index 0ee27eb..08acca4 100644 --- a/tests/test_adpaters.py +++ b/tests/test_adapters.py @@ -1,6 +1,8 @@ import enum +from typing import Union import pydantic +import pytest from pydantic.fields import Field from kor.adapters import _translate_pydantic_to_kor, from_pydantic @@ -93,6 +95,43 @@ class Toy(pydantic.BaseModel): ) +def test_convert_pydantic_with_union() -> None: + """Test behavior with Union field.""" + + class Toy(pydantic.BaseModel): + """Toy pydantic object.""" + + a: Union[int, float, None] + + node = _translate_pydantic_to_kor(Toy) + assert node == Object( + id="toy", + attributes=[ + Text( + # Any union type of primitives is mapped to a text field for now. + id="a" + ), + ], + ) + + +def test_convert_pydantic_with_complex_union() -> None: + """Test behavior with Union field that has nested pydantic objects.""" + + class Child(pydantic.BaseModel): + """Child pydantic object.""" + + y: str + + class ModelWithComplexUnion(pydantic.BaseModel): + """Model that has a union with a pydantic object.""" + + x: Union[Child, int] + + with pytest.raises(NotImplementedError): + _translate_pydantic_to_kor(ModelWithComplexUnion) + + def test_from_pydantic() -> None: """Test from pydantic function."""