diff --git a/strawberry_django/fields/base.py b/strawberry_django/fields/base.py index d19543ca..7c62e42a 100644 --- a/strawberry_django/fields/base.py +++ b/strawberry_django/fields/base.py @@ -38,7 +38,7 @@ _QS = TypeVar("_QS", bound="models.QuerySet") try: - from django.db.models.fields.generated import GeneratedField + from django.db.models import GeneratedField # type: ignore except ImportError: GeneratedField = None @@ -207,7 +207,7 @@ def resolve_type( self.origin_django_type, ) if is_optional( - model_field.output_field + model_field.output_field # type: ignore if GeneratedField is not None and isinstance(model_field, GeneratedField) else model_field, diff --git a/strawberry_django/fields/types.py b/strawberry_django/fields/types.py index 6a404016..bae6234f 100644 --- a/strawberry_django/fields/types.py +++ b/strawberry_django/fields/types.py @@ -39,7 +39,7 @@ TextChoicesField = None try: - from django.db.models.fields.generated import GeneratedField + from django.db.models import GeneratedField # type: ignore except ImportError: GeneratedField = None @@ -476,7 +476,7 @@ def resolve_model_field_type( model_field._strawberry_enum = field_type # type: ignore # Generated fields elif GeneratedField is not None and isinstance(model_field, GeneratedField): - model_field_type = type(model_field.output_field) + model_field_type = type(model_field.output_field) # type: ignore field_type = field_type_map.get(model_field_type, NotImplemented) # Every other Field possibility else: diff --git a/tests/fields/test_types.py b/tests/fields/test_types.py index 364fc1e1..42d35045 100644 --- a/tests/fields/test_types.py +++ b/tests/fields/test_types.py @@ -2,7 +2,7 @@ import decimal import enum import uuid -from typing import Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Union, cast import pytest import strawberry @@ -23,6 +23,11 @@ from strawberry_django.fields.field import StrawberryDjangoField from strawberry_django.type import _process_type +try: + from django.db.models import GeneratedField # type: ignore +except ImportError: + GeneratedField = None + class FieldTypesModel(models.Model): boolean = models.BooleanField() @@ -48,21 +53,21 @@ class FieldTypesModel(models.Model): uuid = models.UUIDField() json = models.JSONField() generated_decimal = ( - models.GeneratedField( + GeneratedField( expression=models.F("decimal") * 2, db_persist=True, output_field=models.DecimalField(), ) - if hasattr(models, "GeneratedField") + if GeneratedField is not None else None ) generated_nullable_decimal = ( - models.GeneratedField( + GeneratedField( expression=models.F("decimal") * 2, db_persist=True, output_field=models.DecimalField(null=True, blank=True), ) - if hasattr(models, "GeneratedField") + if GeneratedField is not None else None ) foreign_key = models.ForeignKey( @@ -110,7 +115,7 @@ class Type: uuid: auto json: auto - expected_types = [ + expected_types: list[tuple[str, Any]] = [ ("id", strawberry.ID), ("boolean", bool), ("char", str),