Skip to content

Commit

Permalink
fix: Extract interface definition in optimizer to fix django-polymorp…
Browse files Browse the repository at this point in the history
…hic (#556)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Thiago Bellini Ribeiro <[email protected]>
  • Loading branch information
3 people authored Jun 17, 2024
1 parent 890b109 commit bfc598a
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 15 deletions.
31 changes: 30 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ pytest-mock = "^3.5.1"
pytest-snapshot = "^0.9.0"
pytest-watch = "^4.2.0"
ruff = "^0.4.1"
django-polymorphic = "^3.1.0"
setuptools = "^70.0.0"

[tool.poetry.extras]
debug-toolbar = ["django-debug-toolbar"]
Expand Down
40 changes: 26 additions & 14 deletions strawberry_django/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from django.db.models.query import QuerySet
from graphql import (
FieldNode,
GraphQLInterfaceType,
GraphQLObjectType,
GraphQLOutputType,
GraphQLWrappingType,
Expand Down Expand Up @@ -354,7 +355,7 @@ def _get_prefetch_queryset(
remote_model: type[models.Model],
schema: Schema,
field: StrawberryField,
parent_type: GraphQLObjectType,
parent_type: GraphQLObjectType | GraphQLInterfaceType,
field_node: FieldNode,
*,
config: OptimizerConfig | None,
Expand Down Expand Up @@ -402,7 +403,7 @@ def _optimize_prefetch_queryset(
qs: QuerySet[_M],
schema: Schema,
field: StrawberryField,
parent_type: GraphQLObjectType,
parent_type: GraphQLObjectType | GraphQLInterfaceType,
field_node: FieldNode,
*,
config: OptimizerConfig | None,
Expand Down Expand Up @@ -499,13 +500,13 @@ def _optimize_prefetch_queryset(

def _get_selections(
info: GraphQLResolveInfo,
parent_type: GraphQLObjectType,
parent_type: GraphQLObjectType | GraphQLInterfaceType,
) -> dict[str, list[FieldNode]]:
return collect_sub_fields(
info.schema,
info.fragments,
info.variable_values,
parent_type,
cast(GraphQLObjectType, parent_type),
info.field_nodes,
)

Expand All @@ -514,14 +515,14 @@ def _generate_selection_resolve_info(
info: GraphQLResolveInfo,
field_nodes: list[FieldNode],
return_type: GraphQLOutputType,
parent_type: GraphQLObjectType,
parent_type: GraphQLObjectType | GraphQLInterfaceType,
):
field_node = field_nodes[0]
return GraphQLResolveInfo(
field_name=field_node.name.value,
field_nodes=field_nodes,
return_type=return_type,
parent_type=parent_type,
parent_type=cast(GraphQLObjectType, parent_type),
path=info.path.add_key(0).add_key(field_node.name.value, parent_type.name),
schema=info.schema,
fragments=info.fragments,
Expand All @@ -538,7 +539,7 @@ def _get_model_hints(
schema: Schema,
object_definition: StrawberryObjectDefinition,
*,
parent_type: GraphQLObjectType,
parent_type: GraphQLObjectType | GraphQLInterfaceType,
info: GraphQLResolveInfo,
config: OptimizerConfig | None = None,
prefix: str = "",
Expand Down Expand Up @@ -794,12 +795,22 @@ def _get_model_hints(
return store


def _get_gql_definition(
schema: Schema,
definition: StrawberryObjectDefinition,
) -> GraphQLInterfaceType | GraphQLObjectType:
if definition.is_interface:
return schema.schema_converter.from_interface(definition)

return schema.schema_converter.from_object(definition)


def _get_model_hints_from_connection(
model: type[models.Model],
schema: Schema,
object_definition: StrawberryObjectDefinition,
*,
parent_type: GraphQLObjectType,
parent_type: GraphQLObjectType | GraphQLInterfaceType,
info: GraphQLResolveInfo,
config: OptimizerConfig | None = None,
prefix: str = "",
Expand Down Expand Up @@ -828,9 +839,11 @@ def _get_model_hints_from_connection(
e_type = e_definition.resolve_generic(
relay.Edge[cast(Type[relay.Node], n_type)],
)
e_gql_definition = schema.schema_converter.from_object(
e_gql_definition = _get_gql_definition(
schema,
get_object_definition(e_type, strict=True),
)
assert isinstance(e_gql_definition, GraphQLObjectType)
e_info = _generate_selection_resolve_info(
info,
edges,
Expand All @@ -842,7 +855,8 @@ def _get_model_hints_from_connection(
if node.name.value != "node":
continue

n_gql_definition = schema.schema_converter.from_object(n_definition)
n_gql_definition = _get_gql_definition(schema, n_definition)
assert isinstance(n_gql_definition, GraphQLObjectType)
n_info = _generate_selection_resolve_info(
info,
nodes,
Expand Down Expand Up @@ -913,9 +927,7 @@ def optimize(
return qs

# Avoid optimizing twice and also modify an already resolved queryset
if (
is_optimized(qs) or qs._result_cache is not None # type: ignore
):
if is_optimized(qs) or qs._result_cache is not None: # type: ignore
return qs

if isinstance(info, Info):
Expand Down Expand Up @@ -952,7 +964,7 @@ def optimize(
object_definitions = [object_definition]

for inner_object_definition in object_definitions:
parent_type = schema.schema_converter.from_object(inner_object_definition)
parent_type = _get_gql_definition(schema, inner_object_definition)
new_store = _get_model_hints(
qs.model,
schema,
Expand Down
1 change: 1 addition & 0 deletions tests/django_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,6 @@
[
"tests",
"tests.projects",
"tests.polymorphism",
],
)
Empty file added tests/polymorphism/__init__.py
Empty file.
14 changes: 14 additions & 0 deletions tests/polymorphism/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from django.db import models
from polymorphic.models import PolymorphicModel


class Project(PolymorphicModel):
topic = models.CharField(max_length=30)


class ArtProject(Project):
artist = models.CharField(max_length=30)


class ResearchProject(Project):
supervisor = models.CharField(max_length=30)
35 changes: 35 additions & 0 deletions tests/polymorphism/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import List

import strawberry

import strawberry_django
from strawberry_django.optimizer import DjangoOptimizerExtension

from .models import ArtProject, Project, ResearchProject


@strawberry_django.interface(Project)
class ProjectType:
topic: strawberry.auto


@strawberry_django.type(ArtProject)
class ArtProjectType(ProjectType):
artist: strawberry.auto


@strawberry_django.type(ResearchProject)
class ResearchProjectType(ProjectType):
supervisor: strawberry.auto


@strawberry.type
class Query:
projects: List[ProjectType] = strawberry_django.field()


schema = strawberry.Schema(
query=Query,
types=[ArtProjectType, ResearchProjectType],
extensions=[DjangoOptimizerExtension],
)
38 changes: 38 additions & 0 deletions tests/polymorphism/test_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest

from .models import ArtProject, ResearchProject
from .schema import schema


@pytest.mark.django_db(transaction=True)
def test_polymorphic_interface_query():
ap = ArtProject.objects.create(topic="Art", artist="Artist")
rp = ResearchProject.objects.create(topic="Research", supervisor="Supervisor")

query = """\
query {
projects {
__typename
topic
... on ArtProjectType {
artist
}
... on ResearchProjectType {
supervisor
}
}
}
"""

result = schema.execute_sync(query)
assert not result.errors
assert result.data == {
"projects": [
{"__typename": "ArtProjectType", "topic": ap.topic, "artist": ap.artist},
{
"__typename": "ResearchProjectType",
"topic": rp.topic,
"supervisor": rp.supervisor,
},
]
}

0 comments on commit bfc598a

Please sign in to comment.