From 2567dc9d7051996c03eadfb03ac1030a9c6150a7 Mon Sep 17 00:00:00 2001 From: "T. Franzel" Date: Thu, 11 Nov 2021 11:44:45 +0100 Subject: [PATCH] improve mypy typing #600 --- drf_spectacular/drainage.py | 33 ++++++++------ drf_spectacular/generators.py | 2 +- drf_spectacular/openapi.py | 38 +++++++++------- drf_spectacular/plumbing.py | 86 ++++++++++++++++++----------------- drf_spectacular/utils.py | 4 +- requirements/docs.txt | 1 + 6 files changed, 90 insertions(+), 74 deletions(-) diff --git a/drf_spectacular/drainage.py b/drf_spectacular/drainage.py index d329fe95..39640439 100644 --- a/drf_spectacular/drainage.py +++ b/drf_spectacular/drainage.py @@ -2,14 +2,17 @@ import functools import sys from collections import defaultdict -from typing import DefaultDict +from typing import Any, DefaultDict if sys.version_info >= (3, 8): from typing import Final, Literal, _TypedDictMeta # type: ignore[attr-defined] # noqa: F401 else: - from typing_extensions import ( # type: ignore[attr-defined] # noqa: F401 - Final, Literal, _TypedDictMeta, - ) + from typing_extensions import Final, Literal, _TypedDictMeta # noqa: F401 + +if sys.version_info >= (3, 10): + from typing import TypeGuard # noqa: F401 +else: + from typing_extensions import TypeGuard # noqa: F401 class GeneratorStats: @@ -33,11 +36,11 @@ def silence(self): finally: self.silent = tmp - def reset(self): + def reset(self) -> None: self._warn_cache.clear() self._error_cache.clear() - def emit(self, msg, severity): + def emit(self, msg: str, severity: str) -> None: assert severity in ['warning', 'error'] msg = _get_current_trace() + str(msg) cache = self._warn_cache if severity == 'warning' else self._error_cache @@ -45,7 +48,7 @@ def emit(self, msg, severity): print(f'{severity.capitalize()} #{len(cache)}: {msg}', file=sys.stderr) cache[msg] += 1 - def emit_summary(self): + def emit_summary(self) -> None: if not self.silent and (self._warn_cache or self._error_cache): print( f'\nSchema generation summary:\n' @@ -58,15 +61,15 @@ def emit_summary(self): GENERATOR_STATS = GeneratorStats() -def warn(msg): +def warn(msg: str) -> None: GENERATOR_STATS.emit(msg, 'warning') -def error(msg): +def error(msg: str) -> None: GENERATOR_STATS.emit(msg, 'error') -def reset_generator_stats(): +def reset_generator_stats() -> None: GENERATOR_STATS.reset() @@ -74,7 +77,7 @@ def reset_generator_stats(): @contextlib.contextmanager -def add_trace_message(trace_message): +def add_trace_message(trace_message: str): """ Adds a message to be used as a prefix when emitting warnings and errors. """ @@ -83,11 +86,11 @@ def add_trace_message(trace_message): _TRACES.pop() -def _get_current_trace(): +def _get_current_trace() -> str: return ''.join(f"{trace}: " for trace in _TRACES if trace) -def has_override(obj, prop): +def has_override(obj, prop: str) -> bool: if isinstance(obj, functools.partial): obj = obj.func if not hasattr(obj, '_spectacular_annotation'): @@ -97,7 +100,7 @@ def has_override(obj, prop): return True -def get_override(obj, prop, default=None): +def get_override(obj, prop: str, default: Any = None) -> Any: if isinstance(obj, functools.partial): obj = obj.func if not has_override(obj, prop): @@ -105,7 +108,7 @@ def get_override(obj, prop, default=None): return obj._spectacular_annotation[prop] -def set_override(obj, prop, value): +def set_override(obj, prop: str, value: Any): if not hasattr(obj, '_spectacular_annotation'): obj._spectacular_annotation = {} elif '_spectacular_annotation' not in obj.__dict__: diff --git a/drf_spectacular/generators.py b/drf_spectacular/generators.py index 8850a933..051cd252 100644 --- a/drf_spectacular/generators.py +++ b/drf_spectacular/generators.py @@ -3,7 +3,7 @@ from django.urls import URLPattern, URLResolver from rest_framework import views, viewsets -from rest_framework.schemas.generators import BaseSchemaGenerator # type: ignore +from rest_framework.schemas.generators import BaseSchemaGenerator from rest_framework.schemas.generators import EndpointEnumerator as BaseEndpointEnumerator from rest_framework.settings import api_settings diff --git a/drf_spectacular/openapi.py b/drf_spectacular/openapi.py index 0ea88bc8..e4ed735d 100644 --- a/drf_spectacular/openapi.py +++ b/drf_spectacular/openapi.py @@ -1,7 +1,7 @@ import copy import re -import typing from collections import defaultdict +from typing import Any, Dict, List, Optional, Union import uritemplate from django.core import exceptions as django_exceptions @@ -13,7 +13,7 @@ from rest_framework.generics import CreateAPIView, GenericAPIView, ListCreateAPIView from rest_framework.mixins import ListModelMixin from rest_framework.schemas.inspectors import ViewInspector -from rest_framework.schemas.utils import get_pk_description # type: ignore +from rest_framework.schemas.utils import get_pk_description from rest_framework.settings import api_settings from rest_framework.utils.model_meta import get_field_info from rest_framework.views import APIView @@ -36,7 +36,7 @@ ) from drf_spectacular.settings import spectacular_settings from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import OpenApiParameter, OpenApiResponse +from drf_spectacular.utils import Direction, OpenApiParameter, OpenApiResponse, _SerializerType class AutoSchema(ViewInspector): @@ -135,7 +135,7 @@ def _is_create_operation(self): return True return False - def get_override_parameters(self): + def get_override_parameters(self) -> List[Union[OpenApiParameter, _SerializerType]]: """ override this for custom behaviour """ return [] @@ -189,13 +189,13 @@ def _process_override_parameters(self): warn(f'could not resolve parameter annotation {parameter}. Skipping.') return result - def _get_format_parameters(self): + def _get_format_parameters(self) -> List[dict]: parameters = [] formats = self.map_renderers('format') if api_settings.URL_FORMAT_OVERRIDE and len(formats) > 1: parameters.append(build_parameter_type( name=api_settings.URL_FORMAT_OVERRIDE, - schema=build_basic_type(OpenApiTypes.STR), + schema=build_basic_type(OpenApiTypes.STR), # type: ignore location=OpenApiParameter.QUERY, enum=formats )) @@ -243,14 +243,14 @@ def dict_helper(parameters): else: return list(parameters.values()) - def get_description(self): + def get_description(self) -> str: # type: ignore """ override this for custom behaviour """ action_or_method = getattr(self.view, getattr(self.view, 'action', self.method.lower()), None) view_doc = get_doc(self.view.__class__) action_doc = get_doc(action_or_method) return action_doc or view_doc - def get_summary(self): + def get_summary(self) -> Optional[str]: """ override this for custom behaviour """ return None @@ -305,21 +305,21 @@ def get_auth(self): auths.append({}) return auths - def get_request_serializer(self) -> typing.Any: + def get_request_serializer(self) -> Any: """ override this for custom behaviour """ return self._get_serializer() - def get_response_serializers(self) -> typing.Any: + def get_response_serializers(self) -> Any: """ override this for custom behaviour """ return self._get_serializer() - def get_tags(self) -> typing.List[str]: + def get_tags(self) -> List[str]: """ override this for custom behaviour """ tokenized_path = self._tokenize_path() # use first non-parameter path part as tag return tokenized_path[:1] - def get_extensions(self) -> typing.Dict[str, typing.Any]: + def get_extensions(self) -> Dict[str, Any]: return {} def get_operation_id(self): @@ -1168,7 +1168,7 @@ def _get_response_bodies(self): schema['description'] = _('Unspecified response body') return {'200': self._get_response_for_code(schema, '200')} - def _unwrap_list_serializer(self, serializer, direction) -> typing.Optional[dict]: + def _unwrap_list_serializer(self, serializer, direction) -> Optional[dict]: if is_field(serializer): return self._map_serializer_field(serializer, direction) elif is_basic_serializer(serializer): @@ -1277,7 +1277,11 @@ def _get_response_headers_for_code(self, status_code) -> dict: elif is_serializer(parameter.type): schema = self.resolve_serializer(parameter.type, 'response').ref else: - schema = parameter.type + schema = parameter.type # type: ignore + + if not schema: + warn(f'response parameter {parameter.name} requires non-empty schema') + continue if parameter.location not in [OpenApiParameter.HEADER, OpenApiParameter.COOKIE]: warn(f'incompatible location type ignored for response parameter {parameter.name}') @@ -1303,7 +1307,7 @@ def _get_response_headers_for_code(self, status_code) -> dict: return result - def _get_serializer_name(self, serializer, direction): + def _get_serializer_name(self, serializer, direction: Direction) -> str: serializer_extension = OpenApiSerializerExtension.get_match(serializer) if serializer_extension and serializer_extension.get_name(): # library override mechanisms @@ -1320,6 +1324,8 @@ def _get_serializer_name(self, serializer, direction): else: name = serializer.__class__.__name__ + assert name + if name.endswith('Serializer'): name = name[:-10] @@ -1331,7 +1337,7 @@ def _get_serializer_name(self, serializer, direction): return name - def resolve_serializer(self, serializer, direction) -> ResolvedComponent: + def resolve_serializer(self, serializer: _SerializerType, direction: Direction) -> ResolvedComponent: assert_basic_serializer(serializer) serializer = force_instance(serializer) diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index ae151cd8..21916518 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -12,7 +12,7 @@ from collections import OrderedDict, defaultdict from decimal import Decimal from enum import Enum -from typing import Any, DefaultDict, Generic, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, DefaultDict, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union import inflection import uritemplate @@ -38,12 +38,16 @@ from rest_framework.utils.mediatypes import _MediaType from uritemplate import URITemplate -from drf_spectacular.drainage import Literal, _TypedDictMeta, cache, error, warn +from drf_spectacular.drainage import Literal, TypeGuard, _TypedDictMeta, cache, error, warn from drf_spectacular.settings import spectacular_settings from drf_spectacular.types import ( DJANGO_PATH_CONVERTER_MAPPING, OPENAPI_TYPE_MAPPING, PYTHON_TYPE_MAPPING, OpenApiTypes, + _KnownPythonTypes, +) +from drf_spectacular.utils import ( + OpenApiParameter, _FieldType, _ListSerializerType, _ParameterLocationType, _SchemaType, + _SerializerType, ) -from drf_spectacular.utils import OpenApiParameter try: from django.db.models.enums import Choices # only available in Django>3 @@ -53,14 +57,14 @@ class Choices: # type: ignore # types.UnionType was added in Python 3.10 for new PEP 604 pipe union syntax if hasattr(types, 'UnionType'): - UNION_TYPES: Tuple[Any, ...] = (typing.Union, types.UnionType) # type: ignore + UNION_TYPES: Tuple[Any, ...] = (Union, types.UnionType) # type: ignore else: - UNION_TYPES = (typing.Union,) + UNION_TYPES = (Union,) if sys.version_info >= (3, 8): - CACHED_PROPERTY_FUNCS = (functools.cached_property, cached_property) # type: ignore + CACHED_PROPERTY_FUNCS = (functools.cached_property, cached_property) else: - CACHED_PROPERTY_FUNCS = (cached_property,) # type: ignore + CACHED_PROPERTY_FUNCS = (cached_property,) T = TypeVar('T') @@ -82,7 +86,7 @@ def force_instance(serializer_or_field): return serializer_or_field -def is_serializer(obj) -> bool: +def is_serializer(obj) -> TypeGuard[_SerializerType]: from drf_spectacular.serializers import OpenApiSerializerExtension return ( isinstance(force_instance(obj), serializers.BaseSerializer) @@ -90,21 +94,21 @@ def is_serializer(obj) -> bool: ) -def is_list_serializer(obj) -> bool: +def is_list_serializer(obj) -> TypeGuard[_ListSerializerType]: return isinstance(force_instance(obj), serializers.ListSerializer) -def is_basic_serializer(obj) -> bool: +def is_basic_serializer(obj) -> TypeGuard[_SerializerType]: return is_serializer(obj) and not is_list_serializer(obj) -def is_field(obj): +def is_field(obj) -> TypeGuard[_FieldType]: # make sure obj is a serializer field and nothing else. # guard against serializers because BaseSerializer(Field) return isinstance(force_instance(obj), fields.Field) and not is_serializer(obj) -def is_basic_type(obj, allow_none=True): +def is_basic_type(obj, allow_none=True) -> TypeGuard[_KnownPythonTypes]: if not isinstance(obj, collections.abc.Hashable): return False if not allow_none and (obj is None or obj is OpenApiTypes.NONE): @@ -112,7 +116,7 @@ def is_basic_type(obj, allow_none=True): return obj in get_openapi_type_mapping() or obj in PYTHON_TYPE_MAPPING -def is_patched_serializer(serializer, direction): +def is_patched_serializer(serializer, direction) -> bool: return bool( spectacular_settings.COMPONENT_SPLIT_PATCH and serializer.partial @@ -121,13 +125,13 @@ def is_patched_serializer(serializer, direction): ) -def is_trivial_string_variation(a: str, b: str): +def is_trivial_string_variation(a: str, b: str) -> bool: a = (a or '').strip().lower().replace(' ', '_').replace('-', '_') b = (b or '').strip().lower().replace(' ', '_').replace('-', '_') return a == b -def assert_basic_serializer(serializer): +def assert_basic_serializer(serializer) -> None: assert is_basic_serializer(serializer), ( f'internal assumption violated because we expected a basic serializer here and ' f'instead got a "{serializer}". This may be the result of another app doing ' @@ -173,7 +177,7 @@ def get_view_model(view, emit_warnings=True): ) -def get_doc(obj): +def get_doc(obj) -> str: """ get doc string with fallback on obj's base classes (ignoring DRF documentation). """ def post_cleanup(doc: str): # also clean up trailing whitespace for each line @@ -197,7 +201,7 @@ def safe_index(lst, item): return '' -def get_type_hints(obj): +def get_type_hints(obj) -> Dict[str, Any]: """ unpack wrapped partial object and use actual func object """ if isinstance(obj, functools.partial): obj = obj.func @@ -212,7 +216,7 @@ def get_openapi_type_mapping(): } -def build_generic_type(): +def build_generic_type() -> dict: if spectacular_settings.GENERIC_ADDITIONAL_PROPERTIES is None: return {'type': 'object'} elif spectacular_settings.GENERIC_ADDITIONAL_PROPERTIES == 'bool': @@ -221,7 +225,7 @@ def build_generic_type(): return {'type': 'object', 'additionalProperties': {}} -def build_basic_type(obj): +def build_basic_type(obj: Union[_KnownPythonTypes, OpenApiTypes]) -> Optional[_SchemaType]: """ resolve either enum or actual type and yield schema template for modification """ @@ -237,7 +241,7 @@ def build_basic_type(obj): return dict(openapi_type_mapping[OpenApiTypes.STR]) -def build_array_type(schema, min_length=None, max_length=None): +def build_array_type(schema, min_length=None, max_length=None) -> _SchemaType: schema = {'type': 'array', 'items': schema} if min_length is not None: schema['minLength'] = min_length @@ -247,12 +251,12 @@ def build_array_type(schema, min_length=None, max_length=None): def build_object_type( - properties=None, + properties: Optional[List[dict]] = None, required=None, - description=None, + description: Optional[str] = None, **kwargs -): - schema = {'type': 'object'} +) -> _SchemaType: + schema: _SchemaType = {'type': 'object'} if description: schema['description'] = description.strip() if properties: @@ -265,14 +269,14 @@ def build_object_type( return schema -def build_media_type_object(schema, examples=None): +def build_media_type_object(schema, examples=None) -> _SchemaType: media_type_object = {'schema': schema} if examples: media_type_object['examples'] = examples return media_type_object -def build_examples_list(examples): +def build_examples_list(examples) -> _SchemaType: schema = {} for example in examples: normalized_name = inflection.camelize(example.name.replace(' ', '_')) @@ -292,9 +296,9 @@ def build_examples_list(examples): def build_parameter_type( - name, - schema, - location, + name: str, + schema: dict, + location: _ParameterLocationType, required=False, description=None, enum=None, @@ -305,7 +309,7 @@ def build_parameter_type( allow_blank=True, examples=None, extensions=None, -): +) -> _SchemaType: irrelevant_field_meta = ['readOnly', 'writeOnly'] if location == OpenApiParameter.PATH: irrelevant_field_meta += ['nullable', 'default'] @@ -337,11 +341,11 @@ def build_parameter_type( return schema -def build_choice_field(field): +def build_choice_field(field) -> _SchemaType: choices = list(OrderedDict.fromkeys(field.choices)) # preserve order and remove duplicates if all(isinstance(choice, bool) for choice in choices): - type = 'boolean' + type: Optional[str] = 'boolean' elif all(isinstance(choice, int) for choice in choices): type = 'integer' elif all(isinstance(choice, (int, float, Decimal)) for choice in choices): # `number` includes `integer` @@ -357,7 +361,7 @@ def build_choice_field(field): if field.allow_null: choices.append(None) - schema = { + schema: _SchemaType = { # The value of `enum` keyword MUST be an array and SHOULD be unique. # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.20 'enum': choices @@ -372,7 +376,7 @@ def build_choice_field(field): return schema -def build_bearer_security_scheme_object(header_name, token_prefix, bearer_format=None): +def build_bearer_security_scheme_object(header_name, token_prefix, bearer_format=None) -> _SchemaType: """ Either build a bearer scheme or a fallback due to OpenAPI 3.0.3 limitations """ # normalize Django header quirks if header_name.startswith('HTTP_'): @@ -396,7 +400,7 @@ def build_bearer_security_scheme_object(header_name, token_prefix, bearer_format } -def build_root_object(paths, components, version): +def build_root_object(paths, components, version) -> _SchemaType: settings = spectacular_settings if settings.VERSION and version: version = f'{settings.VERSION} ({version})' @@ -429,7 +433,7 @@ def build_root_object(paths, components, version): return root -def safe_ref(schema): +def safe_ref(schema: _SchemaType) -> _SchemaType: """ ensure that $ref has its own context and does not remove potential sibling entries when $ref is substituted. @@ -439,7 +443,7 @@ def safe_ref(schema): return schema -def append_meta(schema, meta): +def append_meta(schema: _SchemaType, meta: _SchemaType) -> _SchemaType: return safe_ref({**schema, **meta}) @@ -853,7 +857,7 @@ def resolve_regex_path_parameter(path_regex, variable): return None -def is_versioning_supported(versioning_class): +def is_versioning_supported(versioning_class) -> bool: return issubclass(versioning_class, ( versioning.URLPathVersioning, versioning.NamespaceVersioning, @@ -861,7 +865,7 @@ def is_versioning_supported(versioning_class): )) -def operation_matches_version(view, requested_version): +def operation_matches_version(view, requested_version) -> bool: try: version, _ = view.determine_version(view.request, **view.kwargs) except exceptions.NotAcceptable: @@ -907,7 +911,7 @@ def modify_for_versioning(patterns, method, path, view, requested_version): return path -def analyze_named_regex_pattern(path): +def analyze_named_regex_pattern(path: str) -> Dict[str, str]: """ safely extract named groups and their pattern from given regex pattern """ result = {} stack = 0 @@ -1170,7 +1174,7 @@ def resolve_type_hint(hint): raise UnableToProceedError() -def whitelisted(obj: object, classes: List[Type[object]], exact=False): +def whitelisted(obj: object, classes: List[Type[object]], exact=False) -> bool: if not classes: return True if exact: diff --git a/drf_spectacular/utils.py b/drf_spectacular/utils.py index f0f7edd7..77519af7 100644 --- a/drf_spectacular/utils.py +++ b/drf_spectacular/utils.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union from rest_framework.fields import Field, empty -from rest_framework.serializers import Serializer +from rest_framework.serializers import ListSerializer, Serializer from rest_framework.settings import api_settings from drf_spectacular.drainage import ( @@ -10,9 +10,11 @@ ) from drf_spectacular.types import OpenApiTypes, _KnownPythonTypes +_ListSerializerType = Union[ListSerializer, Type[ListSerializer]] _SerializerType = Union[Serializer, Type[Serializer]] _FieldType = Union[Field, Type[Field]] _ParameterLocationType = Literal['query', 'path', 'header', 'cookie'] +_SchemaType = Dict[str, Any] Direction = Literal['request', 'response'] diff --git a/requirements/docs.txt b/requirements/docs.txt index 486e4406..ca42c48f 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,2 +1,3 @@ Sphinx>=4.1.0 sphinx_rtd_theme>=0.5.1 +typing-extensions