Skip to content

Commit

Permalink
Generated fields type resolution (#565)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Thiago Bellini Ribeiro <[email protected]>
  • Loading branch information
3 people authored Jun 27, 2024
1 parent 7b8a6a2 commit d9b288c
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 4 deletions.
15 changes: 14 additions & 1 deletion strawberry_django/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast

import django
from django.db.models import ForeignKey
from strawberry import LazyType, relay
from strawberry.annotation import StrawberryAnnotation
Expand Down Expand Up @@ -37,6 +38,11 @@

_QS = TypeVar("_QS", bound="models.QuerySet")

if django.VERSION >= (5, 0):
from django.db.models import GeneratedField # type: ignore
else:
GeneratedField = None


class StrawberryDjangoFieldBase(StrawberryField):
def __init__(
Expand Down Expand Up @@ -201,8 +207,15 @@ def resolve_type(
),
self.origin_django_type,
)

is_generated_field = GeneratedField is not None and isinstance(
model_field, GeneratedField
)
field_to_check = (
model_field.output_field if is_generated_field else model_field # type: ignore
)
if is_optional(
model_field,
field_to_check,
self.origin_django_type.is_input,
self.origin_django_type.is_partial,
):
Expand Down
11 changes: 11 additions & 0 deletions strawberry_django/fields/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Union,
)

import django
import strawberry
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
from django.db.models import Field, Model, fields
Expand All @@ -38,6 +39,12 @@
IntegerChoicesField = None
TextChoicesField = None

if django.VERSION >= (5, 0):
from django.db.models import GeneratedField # type: ignore
else:
GeneratedField = None


if TYPE_CHECKING:
from strawberry_django.type import StrawberryDjangoDefinition

Expand Down Expand Up @@ -469,6 +476,10 @@ def resolve_model_field_type(
),
)
model_field._strawberry_enum = field_type # type: ignore
# Generated fields
elif GeneratedField is not None and isinstance(model_field, GeneratedField):
model_field_type = type(model_field.output_field) # type: ignore
field_type = field_type_map.get(model_field_type, NotImplemented)
# Every other Field possibility
else:
force_global_id = settings["MAP_AUTO_ID_AS_GLOBAL_ID"]
Expand Down
41 changes: 38 additions & 3 deletions tests/fields/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import decimal
import enum
import uuid
from typing import Dict, List, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import django
import pytest
import strawberry
from django.conf import settings
Expand All @@ -21,6 +22,12 @@

import strawberry_django
from strawberry_django.fields.field import StrawberryDjangoField
from strawberry_django.type import _process_type # noqa: PLC2701

if django.VERSION >= (5, 0):
from django.db.models import GeneratedField # type: ignore
else:
GeneratedField = None


class FieldTypesModel(models.Model):
Expand All @@ -46,6 +53,24 @@ class FieldTypesModel(models.Model):
url = models.URLField()
uuid = models.UUIDField()
json = models.JSONField()
generated_decimal = (
GeneratedField(
expression=models.F("decimal") * 2,
db_persist=True,
output_field=models.DecimalField(),
)
if GeneratedField is not None
else None
)
generated_nullable_decimal = (
GeneratedField(
expression=models.F("decimal") * 2,
db_persist=True,
output_field=models.DecimalField(null=True, blank=True),
)
if GeneratedField is not None
else None
)
foreign_key = models.ForeignKey(
"FieldTypesModel",
blank=True,
Expand Down Expand Up @@ -91,8 +116,7 @@ class Type:
uuid: auto
json: auto

object_definition = get_object_definition(Type, strict=True)
assert [(f.name, f.type) for f in object_definition.fields] == [
expected_types: list[tuple[str, Any]] = [
("id", strawberry.ID),
("boolean", bool),
("char", str),
Expand All @@ -118,6 +142,17 @@ class Type:
("json", JSON),
]

if django.VERSION >= (5, 0):
Type.__annotations__["generated_decimal"] = auto
expected_types.append(("generated_decimal", decimal.Decimal))

Type.__annotations__["generated_nullable_decimal"] = auto
expected_types.append(("generated_nullable_decimal", Optional[decimal.Decimal]))

type_to_test = _process_type(Type, model=FieldTypesModel)
object_definition = get_object_definition(type_to_test, strict=True)
assert [(f.name, f.type) for f in object_definition.fields] == expected_types


def test_subset_of_fields():
@strawberry_django.type(FieldTypesModel)
Expand Down

0 comments on commit d9b288c

Please sign in to comment.