Skip to content

Commit

Permalink
improve mypy typing #600
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed Dec 19, 2021
1 parent ef14a07 commit 2567dc9
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 74 deletions.
33 changes: 18 additions & 15 deletions drf_spectacular/drainage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
import functools
import sys
from collections import defaultdict
from typing import DefaultDict
from typing import Any, DefaultDict

if sys.version_info >= (3, 8):
from typing import Final, Literal, _TypedDictMeta # type: ignore[attr-defined] # noqa: F401
else:
from typing_extensions import ( # type: ignore[attr-defined] # noqa: F401
Final, Literal, _TypedDictMeta,
)
from typing_extensions import Final, Literal, _TypedDictMeta # noqa: F401

if sys.version_info >= (3, 10):
from typing import TypeGuard # noqa: F401
else:
from typing_extensions import TypeGuard # noqa: F401


class GeneratorStats:
Expand All @@ -33,19 +36,19 @@ def silence(self):
finally:
self.silent = tmp

def reset(self):
def reset(self) -> None:
self._warn_cache.clear()
self._error_cache.clear()

def emit(self, msg, severity):
def emit(self, msg: str, severity: str) -> None:
assert severity in ['warning', 'error']
msg = _get_current_trace() + str(msg)
cache = self._warn_cache if severity == 'warning' else self._error_cache
if not self.silent and msg not in cache:
print(f'{severity.capitalize()} #{len(cache)}: {msg}', file=sys.stderr)
cache[msg] += 1

def emit_summary(self):
def emit_summary(self) -> None:
if not self.silent and (self._warn_cache or self._error_cache):
print(
f'\nSchema generation summary:\n'
Expand All @@ -58,23 +61,23 @@ def emit_summary(self):
GENERATOR_STATS = GeneratorStats()


def warn(msg):
def warn(msg: str) -> None:
GENERATOR_STATS.emit(msg, 'warning')


def error(msg):
def error(msg: str) -> None:
GENERATOR_STATS.emit(msg, 'error')


def reset_generator_stats():
def reset_generator_stats() -> None:
GENERATOR_STATS.reset()


_TRACES = []


@contextlib.contextmanager
def add_trace_message(trace_message):
def add_trace_message(trace_message: str):
"""
Adds a message to be used as a prefix when emitting warnings and errors.
"""
Expand All @@ -83,11 +86,11 @@ def add_trace_message(trace_message):
_TRACES.pop()


def _get_current_trace():
def _get_current_trace() -> str:
return ''.join(f"{trace}: " for trace in _TRACES if trace)


def has_override(obj, prop):
def has_override(obj, prop: str) -> bool:
if isinstance(obj, functools.partial):
obj = obj.func
if not hasattr(obj, '_spectacular_annotation'):
Expand All @@ -97,15 +100,15 @@ def has_override(obj, prop):
return True


def get_override(obj, prop, default=None):
def get_override(obj, prop: str, default: Any = None) -> Any:
if isinstance(obj, functools.partial):
obj = obj.func
if not has_override(obj, prop):
return default
return obj._spectacular_annotation[prop]


def set_override(obj, prop, value):
def set_override(obj, prop: str, value: Any):
if not hasattr(obj, '_spectacular_annotation'):
obj._spectacular_annotation = {}
elif '_spectacular_annotation' not in obj.__dict__:
Expand Down
2 changes: 1 addition & 1 deletion drf_spectacular/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from django.urls import URLPattern, URLResolver
from rest_framework import views, viewsets
from rest_framework.schemas.generators import BaseSchemaGenerator # type: ignore
from rest_framework.schemas.generators import BaseSchemaGenerator
from rest_framework.schemas.generators import EndpointEnumerator as BaseEndpointEnumerator
from rest_framework.settings import api_settings

Expand Down
38 changes: 22 additions & 16 deletions drf_spectacular/openapi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import re
import typing
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union

import uritemplate
from django.core import exceptions as django_exceptions
Expand All @@ -13,7 +13,7 @@
from rest_framework.generics import CreateAPIView, GenericAPIView, ListCreateAPIView
from rest_framework.mixins import ListModelMixin
from rest_framework.schemas.inspectors import ViewInspector
from rest_framework.schemas.utils import get_pk_description # type: ignore
from rest_framework.schemas.utils import get_pk_description
from rest_framework.settings import api_settings
from rest_framework.utils.model_meta import get_field_info
from rest_framework.views import APIView
Expand All @@ -36,7 +36,7 @@
)
from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse
from drf_spectacular.utils import Direction, OpenApiParameter, OpenApiResponse, _SerializerType


class AutoSchema(ViewInspector):
Expand Down Expand Up @@ -135,7 +135,7 @@ def _is_create_operation(self):
return True
return False

def get_override_parameters(self):
def get_override_parameters(self) -> List[Union[OpenApiParameter, _SerializerType]]:
""" override this for custom behaviour """
return []

Expand Down Expand Up @@ -189,13 +189,13 @@ def _process_override_parameters(self):
warn(f'could not resolve parameter annotation {parameter}. Skipping.')
return result

def _get_format_parameters(self):
def _get_format_parameters(self) -> List[dict]:
parameters = []
formats = self.map_renderers('format')
if api_settings.URL_FORMAT_OVERRIDE and len(formats) > 1:
parameters.append(build_parameter_type(
name=api_settings.URL_FORMAT_OVERRIDE,
schema=build_basic_type(OpenApiTypes.STR),
schema=build_basic_type(OpenApiTypes.STR), # type: ignore
location=OpenApiParameter.QUERY,
enum=formats
))
Expand Down Expand Up @@ -243,14 +243,14 @@ def dict_helper(parameters):
else:
return list(parameters.values())

def get_description(self):
def get_description(self) -> str: # type: ignore
""" override this for custom behaviour """
action_or_method = getattr(self.view, getattr(self.view, 'action', self.method.lower()), None)
view_doc = get_doc(self.view.__class__)
action_doc = get_doc(action_or_method)
return action_doc or view_doc

def get_summary(self):
def get_summary(self) -> Optional[str]:
""" override this for custom behaviour """
return None

Expand Down Expand Up @@ -305,21 +305,21 @@ def get_auth(self):
auths.append({})
return auths

def get_request_serializer(self) -> typing.Any:
def get_request_serializer(self) -> Any:
""" override this for custom behaviour """
return self._get_serializer()

def get_response_serializers(self) -> typing.Any:
def get_response_serializers(self) -> Any:
""" override this for custom behaviour """
return self._get_serializer()

def get_tags(self) -> typing.List[str]:
def get_tags(self) -> List[str]:
""" override this for custom behaviour """
tokenized_path = self._tokenize_path()
# use first non-parameter path part as tag
return tokenized_path[:1]

def get_extensions(self) -> typing.Dict[str, typing.Any]:
def get_extensions(self) -> Dict[str, Any]:
return {}

def get_operation_id(self):
Expand Down Expand Up @@ -1168,7 +1168,7 @@ def _get_response_bodies(self):
schema['description'] = _('Unspecified response body')
return {'200': self._get_response_for_code(schema, '200')}

def _unwrap_list_serializer(self, serializer, direction) -> typing.Optional[dict]:
def _unwrap_list_serializer(self, serializer, direction) -> Optional[dict]:
if is_field(serializer):
return self._map_serializer_field(serializer, direction)
elif is_basic_serializer(serializer):
Expand Down Expand Up @@ -1277,7 +1277,11 @@ def _get_response_headers_for_code(self, status_code) -> dict:
elif is_serializer(parameter.type):
schema = self.resolve_serializer(parameter.type, 'response').ref
else:
schema = parameter.type
schema = parameter.type # type: ignore

if not schema:
warn(f'response parameter {parameter.name} requires non-empty schema')
continue

if parameter.location not in [OpenApiParameter.HEADER, OpenApiParameter.COOKIE]:
warn(f'incompatible location type ignored for response parameter {parameter.name}')
Expand All @@ -1303,7 +1307,7 @@ def _get_response_headers_for_code(self, status_code) -> dict:

return result

def _get_serializer_name(self, serializer, direction):
def _get_serializer_name(self, serializer, direction: Direction) -> str:
serializer_extension = OpenApiSerializerExtension.get_match(serializer)
if serializer_extension and serializer_extension.get_name():
# library override mechanisms
Expand All @@ -1320,6 +1324,8 @@ def _get_serializer_name(self, serializer, direction):
else:
name = serializer.__class__.__name__

assert name

if name.endswith('Serializer'):
name = name[:-10]

Expand All @@ -1331,7 +1337,7 @@ def _get_serializer_name(self, serializer, direction):

return name

def resolve_serializer(self, serializer, direction) -> ResolvedComponent:
def resolve_serializer(self, serializer: _SerializerType, direction: Direction) -> ResolvedComponent:
assert_basic_serializer(serializer)
serializer = force_instance(serializer)

Expand Down
Loading

0 comments on commit 2567dc9

Please sign in to comment.