Skip to content

Commit

Permalink
Merge pull request #92 from EmadMokhtar/use-get_serialzier_class
Browse files Browse the repository at this point in the history
Use get_serializer_class for Views without serlaizer_class attribute
  • Loading branch information
ekonstantinidis committed May 25, 2016
2 parents e68b968 + 9155175 commit b529776
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 16 deletions.
33 changes: 19 additions & 14 deletions rest_framework_docs/api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,27 @@ def __get_permissions_class__(self):

def __get_serializer_fields__(self):
fields = []
serializer = None

if hasattr(self.callback.cls, 'serializer_class') and hasattr(self.callback.cls.serializer_class, 'get_fields'):
if hasattr(self.callback.cls, 'serializer_class'):
serializer = self.callback.cls.serializer_class
if hasattr(serializer, 'get_fields'):
try:
fields = [{
"name": key,
"type": str(field.__class__.__name__),
"required": field.required
} for key, field in serializer().get_fields().items()]
except KeyError as e:
self.errors = e
fields = []

# FIXME:
# Show more attibutes of `field`?

elif hasattr(self.callback.cls, 'get_serializer_class'):
serializer = self.callback.cls.get_serializer_class(self.pattern.callback.cls())

if hasattr(serializer, 'get_fields'):
try:
fields = [{
"name": key,
"type": str(field.__class__.__name__),
"required": field.required
} for key, field in serializer().get_fields().items()]
except KeyError as e:
self.errors = e
fields = []

# FIXME:
# Show more attibutes of `field`?

return fields

Expand Down
12 changes: 10 additions & 2 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_index_view_with_endpoints(self):
response = self.client.get(reverse('drfdocs'))

self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.context["endpoints"]), 10)
self.assertEqual(len(response.context["endpoints"]), 11)

# Test the login view
self.assertEqual(response.context["endpoints"][0].name_parent, "accounts")
Expand All @@ -38,8 +38,16 @@ def test_index_view_with_endpoints(self):
self.assertEqual(response.context["endpoints"][0].fields[0]["type"], "CharField")
self.assertTrue(response.context["endpoints"][0].fields[0]["required"])

self.assertEqual(response.context["endpoints"][1].name_parent, "accounts")
self.assertEqual(response.context["endpoints"][1].allowed_methods, ['POST', 'OPTIONS'])
self.assertEqual(response.context["endpoints"][1].path, "/accounts/login2/")
self.assertEqual(response.context["endpoints"][1].docstring, "A view that allows users to login providing their username and password. Without serializer_class")
self.assertEqual(len(response.context["endpoints"][1].fields), 2)
self.assertEqual(response.context["endpoints"][1].fields[0]["type"], "CharField")
self.assertTrue(response.context["endpoints"][1].fields[0]["required"])

# The view "OrganisationErroredView" (organisations/(?P<slug>[\w-]+)/errored/) should contain an error.
self.assertEqual(str(response.context["endpoints"][8].errors), "'test_value'")
self.assertEqual(str(response.context["endpoints"][9].errors), "'test_value'")

def test_index_search_with_endpoints(self):
response = self.client.get("%s?search=reset-password" % reverse("drfdocs"))
Expand Down
1 change: 1 addition & 0 deletions tests/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

accounts_urls = [
url(r'^login/$', views.LoginView.as_view(), name="login"),
url(r'^login2/$', views.LoginWithSerilaizerClassView.as_view(), name="login2"),
url(r'^register/$', views.UserRegistrationView.as_view(), name="register"),
url(r'^reset-password/$', view=views.PasswordResetView.as_view(), name="reset-password"),
url(r'^reset-password/confirm/$', views.PasswordResetConfirmView.as_view(), name="reset-password-confirm"),
Expand Down
21 changes: 21 additions & 0 deletions tests/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,24 @@ def delete(self, request, *args, **kwargs):
class OrganisationErroredView(generics.ListAPIView):

serializer_class = serializers.OrganisationErroredSerializer


class LoginWithSerilaizerClassView(APIView):
"""
A view that allows users to login providing their username and password. Without serializer_class
"""

throttle_classes = ()
permission_classes = ()
parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
renderer_classes = (renderers.JSONRenderer,)

def post(self, request):
serializer = self.serializer_class(data=request.data)
serializer.is_valid(raise_exception=True)
user = serializer.validated_data['user']
token, created = Token.objects.get_or_create(user=user)
return Response({'token': token.key})

def get_serializer_class(self):
return AuthTokenSerializer

0 comments on commit b529776

Please sign in to comment.