Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generated fields type resolution #565

Merged
merged 11 commits into from
Jun 27, 2024
15 changes: 14 additions & 1 deletion strawberry_django/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast

import django
from django.db.models import ForeignKey
from strawberry import LazyType, relay
from strawberry.annotation import StrawberryAnnotation
Expand Down Expand Up @@ -37,6 +38,11 @@

_QS = TypeVar("_QS", bound="models.QuerySet")

if django.VERSION >= (5, 0):
from django.db.models import GeneratedField # type: ignore
else:
GeneratedField = None


class StrawberryDjangoFieldBase(StrawberryField):
def __init__(
Expand Down Expand Up @@ -201,8 +207,15 @@ def resolve_type(
),
self.origin_django_type,
)

is_generated_field = GeneratedField is not None and isinstance(
model_field, GeneratedField
)
field_to_check = (
model_field.output_field if is_generated_field else model_field # type: ignore
)
if is_optional(
model_field,
field_to_check,
self.origin_django_type.is_input,
self.origin_django_type.is_partial,
):
Expand Down
11 changes: 11 additions & 0 deletions strawberry_django/fields/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Union,
)

import django
import strawberry
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
from django.db.models import Field, Model, fields
Expand All @@ -38,6 +39,12 @@
IntegerChoicesField = None
TextChoicesField = None

if django.VERSION >= (5, 0):
from django.db.models import GeneratedField # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth adding a comment here that GeneratedField is only available at Django 5.0+ so that we remember to remove it when dropping support for 4.2 in the future?

Another suggestion would be to change this to:

if django.VERSION >= (5, 0):
    from django.db.models import GeneratedField
else:
    GeneratedField = None

This is self-documented, and is the kind of string that I would search for when searching for code that can be dropped/simplified after a version bump

else:
GeneratedField = None


if TYPE_CHECKING:
from strawberry_django.type import StrawberryDjangoDefinition

Expand Down Expand Up @@ -469,6 +476,10 @@ def resolve_model_field_type(
),
)
model_field._strawberry_enum = field_type # type: ignore
# Generated fields
elif GeneratedField is not None and isinstance(model_field, GeneratedField):
model_field_type = type(model_field.output_field) # type: ignore
field_type = field_type_map.get(model_field_type, NotImplemented)
# Every other Field possibility
else:
force_global_id = settings["MAP_AUTO_ID_AS_GLOBAL_ID"]
Expand Down
41 changes: 38 additions & 3 deletions tests/fields/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import decimal
import enum
import uuid
from typing import Dict, List, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import django
import pytest
import strawberry
from django.conf import settings
Expand All @@ -21,6 +22,12 @@

import strawberry_django
from strawberry_django.fields.field import StrawberryDjangoField
from strawberry_django.type import _process_type # noqa: PLC2701

if django.VERSION >= (5, 0):
from django.db.models import GeneratedField # type: ignore
else:
GeneratedField = None


class FieldTypesModel(models.Model):
Expand All @@ -46,6 +53,24 @@ class FieldTypesModel(models.Model):
url = models.URLField()
uuid = models.UUIDField()
json = models.JSONField()
generated_decimal = (
GeneratedField(
expression=models.F("decimal") * 2,
db_persist=True,
output_field=models.DecimalField(),
)
if GeneratedField is not None
else None
)
bellini666 marked this conversation as resolved.
Show resolved Hide resolved
generated_nullable_decimal = (
GeneratedField(
expression=models.F("decimal") * 2,
db_persist=True,
output_field=models.DecimalField(null=True, blank=True),
)
if GeneratedField is not None
else None
)
foreign_key = models.ForeignKey(
"FieldTypesModel",
blank=True,
Expand Down Expand Up @@ -91,8 +116,7 @@ class Type:
uuid: auto
json: auto

object_definition = get_object_definition(Type, strict=True)
assert [(f.name, f.type) for f in object_definition.fields] == [
expected_types: list[tuple[str, Any]] = [
("id", strawberry.ID),
("boolean", bool),
("char", str),
Expand All @@ -118,6 +142,17 @@ class Type:
("json", JSON),
]

if django.VERSION >= (5, 0):
Type.__annotations__["generated_decimal"] = auto
expected_types.append(("generated_decimal", decimal.Decimal))

Type.__annotations__["generated_nullable_decimal"] = auto
expected_types.append(("generated_nullable_decimal", Optional[decimal.Decimal]))

type_to_test = _process_type(Type, model=FieldTypesModel)
bellini666 marked this conversation as resolved.
Show resolved Hide resolved
object_definition = get_object_definition(type_to_test, strict=True)
assert [(f.name, f.type) for f in object_definition.fields] == expected_types
bellini666 marked this conversation as resolved.
Show resolved Hide resolved


def test_subset_of_fields():
@strawberry_django.type(FieldTypesModel)
Expand Down
Loading