diff --git a/rest_framework_docs/api_endpoint.py b/rest_framework_docs/api_endpoint.py index 89a33f8..520fe0b 100644 --- a/rest_framework_docs/api_endpoint.py +++ b/rest_framework_docs/api_endpoint.py @@ -3,11 +3,13 @@ from django.contrib.admindocs.views import simplify_regex from django.utils.encoding import force_str from rest_framework.serializers import BaseSerializer +from rest_framework_docs.settings import DRFSettings class ApiEndpoint(object): def __init__(self, pattern, parent_pattern=None, drf_router=None): + self.settings = DRFSettings().settings self.drf_router = drf_router self.pattern = pattern self.callback = pattern.callback @@ -67,7 +69,7 @@ def __get_allowed_methods__(self): return viewset_methods + view_methods def __get_docstring__(self): - return inspect.getdoc(self.callback) + return self.settings["VIEW_DESCRIPTION_FUNCTION"](self.callback, html=True) def __get_permissions_class__(self): for perm_class in self.pattern.callback.cls.permission_classes: diff --git a/rest_framework_docs/settings.py b/rest_framework_docs/settings.py index 2853a7b..068ca56 100644 --- a/rest_framework_docs/settings.py +++ b/rest_framework_docs/settings.py @@ -1,11 +1,13 @@ from django.conf import settings +from rest_framework.views import get_view_description class DRFSettings(object): def __init__(self): self.drf_settings = { - "HIDE_DOCS": self.get_setting("HIDE_DOCS") or False + "HIDE_DOCS": self.get_setting("HIDE_DOCS") or False, + "VIEW_DESCRIPTION_FUNCTION": self.get_setting("VIEW_DESCRIPTION_FUNCTION") or get_view_description } def get_setting(self, name): diff --git a/tests/tests.py b/tests/tests.py index 998faee..4482a09 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -33,7 +33,7 @@ def test_index_view_with_endpoints(self): self.assertEqual(response.context["endpoints"][0].name_parent, "accounts") self.assertEqual(response.context["endpoints"][0].allowed_methods, ['POST', 'OPTIONS']) self.assertEqual(response.context["endpoints"][0].path, "/accounts/login/") - self.assertEqual(response.context["endpoints"][0].docstring, "A view that allows users to login providing their username and password.") + self.assertEqual(response.context["endpoints"][0].docstring, "
A view that allows users to login providing their username and password.
") self.assertEqual(len(response.context["endpoints"][0].fields), 2) self.assertEqual(response.context["endpoints"][0].fields[0]["type"], "CharField") self.assertTrue(response.context["endpoints"][0].fields[0]["required"]) @@ -41,7 +41,7 @@ def test_index_view_with_endpoints(self): self.assertEqual(response.context["endpoints"][1].name_parent, "accounts") self.assertEqual(response.context["endpoints"][1].allowed_methods, ['POST', 'OPTIONS']) self.assertEqual(response.context["endpoints"][1].path, "/accounts/login2/") - self.assertEqual(response.context["endpoints"][1].docstring, "A view that allows users to login providing their username and password. Without serializer_class") + self.assertEqual(response.context["endpoints"][1].docstring, "A view that allows users to login providing their username and password. Without serializer_class
") self.assertEqual(len(response.context["endpoints"][1].fields), 2) self.assertEqual(response.context["endpoints"][1].fields[0]["type"], "CharField") self.assertTrue(response.context["endpoints"][1].fields[0]["required"])