From ed5a999ace76d967f96731ccc799c42cdc72f426 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arttu=20Per=C3=A4l=C3=A4?= Date: Tue, 19 Sep 2023 08:10:35 +0300 Subject: [PATCH] Fix schema generation for nested serializers (#1177) * Fix Serializer schema generation when used as a ListField child * Fix Serializer schema generation when used in another serializer --- CHANGELOG.md | 6 +++ example/factories.py | 22 ++++++++++ example/migrations/0013_questionnaire.py | 28 +++++++++++++ .../migrations/0014_questionnaire_metadata.py | 18 +++++++++ example/models.py | 6 +++ example/serializers.py | 20 ++++++++++ example/tests/conftest.py | 2 + example/tests/test_openapi.py | 40 +++++++++++++++++++ example/tests/test_serializers.py | 33 +++++++++++++++ example/urls.py | 2 + example/urls_test.py | 2 + example/views.py | 7 ++++ rest_framework_json_api/schemas/openapi.py | 8 ++++ 13 files changed, 194 insertions(+) create mode 100644 example/migrations/0013_questionnaire.py create mode 100644 example/migrations/0014_questionnaire_metadata.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 243f13b3..7a3b4a57 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Note that in line with [Django REST framework policy](https://www.django-rest-framework.org/topics/release-notes/), any parts of the framework not mentioned in the documentation should generally be considered private API, and may be subject to change. +## [Unreleased] + +### Fixed + +* Fixed OpenAPI schema generation for `Serializer` when used inside another `Serializer` or as a child of `ListField`. + ## [6.1.0] - 2023-08-25 ### Added diff --git a/example/factories.py b/example/factories.py index 4ca1e0b1..37340df4 100644 --- a/example/factories.py +++ b/example/factories.py @@ -12,6 +12,7 @@ Company, Entry, ProjectType, + Questionnaire, ResearchProject, TaggedItem, ) @@ -140,3 +141,24 @@ def future_projects(self, create, extracted, **kwargs): if extracted: for project in extracted: self.future_projects.add(project) + + +class QuestionnaireFactory(factory.django.DjangoModelFactory): + class Meta: + model = Questionnaire + + name = factory.LazyAttribute(lambda x: faker.text()) + questions = [ + { + "text": "What is your name?", + "required": True, + }, + { + "text": "What is your quest?", + "required": False, + }, + { + "text": "What is the air-speed velocity of an unladen swallow?", + }, + ] + metadata = {"author": "Bridgekeeper"} diff --git a/example/migrations/0013_questionnaire.py b/example/migrations/0013_questionnaire.py new file mode 100644 index 00000000..0a3b7cd2 --- /dev/null +++ b/example/migrations/0013_questionnaire.py @@ -0,0 +1,28 @@ +# Generated by Django 4.2.5 on 2023-09-07 02:35 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("example", "0012_author_full_name"), + ] + + operations = [ + migrations.CreateModel( + name="Questionnaire", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=100)), + ("questions", models.JSONField()), + ], + ), + ] diff --git a/example/migrations/0014_questionnaire_metadata.py b/example/migrations/0014_questionnaire_metadata.py new file mode 100644 index 00000000..e320ebb8 --- /dev/null +++ b/example/migrations/0014_questionnaire_metadata.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.5 on 2023-09-12 07:12 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("example", "0013_questionnaire"), + ] + + operations = [ + migrations.AddField( + model_name="questionnaire", + name="metadata", + field=models.JSONField(default={}), + preserve_default=False, + ), + ] diff --git a/example/models.py b/example/models.py index 8fc86c22..35a15a8e 100644 --- a/example/models.py +++ b/example/models.py @@ -180,3 +180,9 @@ class Company(models.Model): def __str__(self): return self.name + + +class Questionnaire(models.Model): + name = models.CharField(max_length=100) + questions = models.JSONField() + metadata = models.JSONField() diff --git a/example/serializers.py b/example/serializers.py index 3d94e6cc..94fe8556 100644 --- a/example/serializers.py +++ b/example/serializers.py @@ -18,6 +18,7 @@ LabResults, Project, ProjectType, + Questionnaire, ResearchProject, TaggedItem, ) @@ -421,3 +422,22 @@ class CompanySerializer(serializers.ModelSerializer): class Meta: model = Company fields = "__all__" + + +class QuestionSerializer(serializers.Serializer): + text = serializers.CharField() + required = serializers.BooleanField(default=False) + + +class QuestionnaireMetadataSerializer(serializers.Serializer): + author = serializers.CharField() + producer = serializers.CharField(default=None) + + +class QuestionnaireSerializer(serializers.ModelSerializer): + questions = serializers.ListField(child=QuestionSerializer()) + metadata = QuestionnaireMetadataSerializer() + + class Meta: + model = Questionnaire + fields = ("name", "questions", "metadata") diff --git a/example/tests/conftest.py b/example/tests/conftest.py index 22ab6bd1..6e4b05ba 100644 --- a/example/tests/conftest.py +++ b/example/tests/conftest.py @@ -12,6 +12,7 @@ CommentFactory, CompanyFactory, EntryFactory, + QuestionnaireFactory, ResearchProjectFactory, TaggedItemFactory, ) @@ -27,6 +28,7 @@ register(ArtProjectFactory) register(ResearchProjectFactory) register(CompanyFactory) +register(QuestionnaireFactory) @pytest.fixture diff --git a/example/tests/test_openapi.py b/example/tests/test_openapi.py index 5710da2a..2333dd6a 100644 --- a/example/tests/test_openapi.py +++ b/example/tests/test_openapi.py @@ -125,6 +125,46 @@ def test_schema_id_field(): assert "id" not in company_properties["attributes"]["properties"] +def test_schema_subserializers(): + """Schema for child Serializers reflects the actual response structure.""" + patterns = [ + re_path( + "^questionnaires/?$", views.QuestionnaireViewset.as_view({"get": "list"}) + ), + ] + generator = SchemaGenerator(patterns=patterns) + + request = create_request("/") + schema = generator.get_schema(request=request) + + assert { + "type": "object", + "properties": { + "metadata": { + "type": "object", + "properties": { + "author": {"type": "string"}, + "producer": {"type": "string"}, + }, + "required": ["author"], + }, + "questions": { + "type": "array", + "items": { + "type": "object", + "properties": { + "text": {"type": "string"}, + "required": {"type": "boolean", "default": False}, + }, + "required": ["text"], + }, + }, + "name": {"type": "string", "maxLength": 100}, + }, + "required": ["name", "questions", "metadata"], + } == schema["components"]["schemas"]["Questionnaire"]["properties"]["attributes"] + + def test_schema_parameters_include(): """Include paramater is only used when serializer defines included_serializers.""" patterns = [ diff --git a/example/tests/test_serializers.py b/example/tests/test_serializers.py index 37f50b53..9ad487e5 100644 --- a/example/tests/test_serializers.py +++ b/example/tests/test_serializers.py @@ -224,6 +224,39 @@ def test_model_serializer_with_implicit_fields(self, comment, client): assert response.status_code == 200 assert expected == response.json() + def test_model_serializer_with_subserializers(self, questionnaire, client): + expected = { + "data": { + "type": "questionnaires", + "id": str(questionnaire.pk), + "attributes": { + "name": questionnaire.name, + "questions": [ + { + "text": "What is your name?", + "required": True, + }, + { + "text": "What is your quest?", + "required": False, + }, + { + "text": "What is the air-speed velocity of an unladen swallow?", + "required": False, + }, + ], + "metadata": {"author": "Bridgekeeper", "producer": None}, + }, + }, + } + + response = client.get( + reverse("questionnaire-detail", kwargs={"pk": questionnaire.pk}) + ) + + assert response.status_code == 200 + assert expected == response.json() + class TestPolymorphicModelSerializer(TestCase): def setUp(self): diff --git a/example/urls.py b/example/urls.py index 3d1cf2fa..413d058d 100644 --- a/example/urls.py +++ b/example/urls.py @@ -19,6 +19,7 @@ NonPaginatedEntryViewSet, ProjectTypeViewset, ProjectViewset, + QuestionnaireViewset, ) router = routers.DefaultRouter(trailing_slash=False) @@ -32,6 +33,7 @@ router.register(r"projects", ProjectViewset) router.register(r"project-types", ProjectTypeViewset) router.register(r"lab-results", LabResultViewSet) +router.register(r"questionnaires", QuestionnaireViewset) urlpatterns = [ path("", include(router.urls)), diff --git a/example/urls_test.py b/example/urls_test.py index 92802a81..bb8fbecf 100644 --- a/example/urls_test.py +++ b/example/urls_test.py @@ -20,6 +20,7 @@ NonPaginatedEntryViewSet, ProjectTypeViewset, ProjectViewset, + QuestionnaireViewset, ) router = routers.DefaultRouter(trailing_slash=False) @@ -38,6 +39,7 @@ router.register(r"projects", ProjectViewset) router.register(r"project-types", ProjectTypeViewset) router.register(r"lab-results", LabResultViewSet) +router.register(r"questionnaires", QuestionnaireViewset) # for the old tests router.register(r"identities", Identity) diff --git a/example/views.py b/example/views.py index b0d92811..9c949684 100644 --- a/example/views.py +++ b/example/views.py @@ -29,6 +29,7 @@ LabResults, Project, ProjectType, + Questionnaire, ) from example.serializers import ( AuthorDetailSerializer, @@ -43,6 +44,7 @@ LabResultsSerializer, ProjectSerializer, ProjectTypeSerializer, + QuestionnaireSerializer, ) HTTP_422_UNPROCESSABLE_ENTITY = 422 @@ -292,3 +294,8 @@ class LabResultViewSet(ReadOnlyModelViewSet): "__all__": [], "author": ["author__bio", "author__entries"], } + + +class QuestionnaireViewset(ModelViewSet): + queryset = Questionnaire.objects.all() + serializer_class = QuestionnaireSerializer diff --git a/rest_framework_json_api/schemas/openapi.py b/rest_framework_json_api/schemas/openapi.py index 52f08da6..c0d9ca3a 100644 --- a/rest_framework_json_api/schemas/openapi.py +++ b/rest_framework_json_api/schemas/openapi.py @@ -681,6 +681,14 @@ def map_serializer(self, serializer): and 'links'. """ # TODO: remove attributes, etc. for relationshipView?? + if isinstance( + serializer.parent, (serializers.ListField, serializers.BaseSerializer) + ): + # Return plain non-JSON:API serializer schema for serializers nested inside + # a Serializer or a ListField, as those don't use the full JSON:API + # serializer schemas. + return super().map_serializer(serializer) + required = [] attributes = {} relationships_required = []