Skip to content

Commit

Permalink
Minimal handling of union types for pydantic (#91)
Browse files Browse the repository at this point in the history
This adds minimal handling and safeguards for Union type when working with pydantic models.
  • Loading branch information
eyurtsev authored Mar 25, 2023
1 parent ce4ed32 commit c2777ff
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 8 deletions.
48 changes: 40 additions & 8 deletions kor/adapters.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,29 @@
"""Adapters to convert from validation frameworks to Kor internal representation."""
import enum
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, get_origin
from typing import (
Any,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
get_args,
get_origin,
)

from pydantic import BaseModel

from .nodes import ExtractionSchemaNode, Number, Object, Option, Selection, Text
from .validators import PydanticValidator, Validator

# Not going to support dicts or lists since that requires recursive checks.
# May make sense to either drop the internal representation, or properly extend it
# to handle Lists, Unions etc.
# Not worth the effort, until it's clear that folks are using this functionality.
PRIMITIVE_TYPES = {str, float, int, type(None)}


def _translate_pydantic_to_kor(
model_class: Type[BaseModel],
Expand Down Expand Up @@ -43,17 +60,32 @@ def _translate_pydantic_to_kor(
type_ = field.type_
field_many = get_origin(field.outer_type_) is list
attribute: Union[ExtractionSchemaNode, Selection, "Object"]
if issubclass(type_, BaseModel):
attribute = _translate_pydantic_to_kor(
type_,
description=field_description,
# Precedence matters here since bool is a subclass of int
if get_origin(type_) is Union:
args = get_args(type_)

if not all(arg in PRIMITIVE_TYPES for arg in args):
raise NotImplementedError(
"Union of non-primitive types not supported. Issue with"
f"field: `{field_name}`. Has type: `{type_}`"
)

attribute = Text(
id=field_name,
examples=field_examples,
description=field_description,
many=field_many,
name=field_name,
)
else:
# Precedence matters here since bool is a subclass of int
if issubclass(type_, bool):
if issubclass(type_, BaseModel):
attribute = _translate_pydantic_to_kor(
type_,
description=field_description,
examples=field_examples,
many=field_many,
name=field_name,
)
elif issubclass(type_, bool):
attribute = Text(
id=field_name,
examples=field_examples,
Expand Down
39 changes: 39 additions & 0 deletions tests/test_adpaters.py → tests/test_adapters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import enum
from typing import Union

import pydantic
import pytest
from pydantic.fields import Field

from kor.adapters import _translate_pydantic_to_kor, from_pydantic
Expand Down Expand Up @@ -93,6 +95,43 @@ class Toy(pydantic.BaseModel):
)


def test_convert_pydantic_with_union() -> None:
"""Test behavior with Union field."""

class Toy(pydantic.BaseModel):
"""Toy pydantic object."""

a: Union[int, float, None]

node = _translate_pydantic_to_kor(Toy)
assert node == Object(
id="toy",
attributes=[
Text(
# Any union type of primitives is mapped to a text field for now.
id="a"
),
],
)


def test_convert_pydantic_with_complex_union() -> None:
"""Test behavior with Union field that has nested pydantic objects."""

class Child(pydantic.BaseModel):
"""Child pydantic object."""

y: str

class ModelWithComplexUnion(pydantic.BaseModel):
"""Model that has a union with a pydantic object."""

x: Union[Child, int]

with pytest.raises(NotImplementedError):
_translate_pydantic_to_kor(ModelWithComplexUnion)


def test_from_pydantic() -> None:
"""Test from pydantic function."""

Expand Down

0 comments on commit c2777ff

Please sign in to comment.