From dee63283daed1a61a2a33d2fa5872e41c9b94b24 Mon Sep 17 00:00:00 2001 From: "T. Franzel" Date: Thu, 11 Nov 2021 11:44:45 +0100 Subject: [PATCH] improve mypy typing #600 --- drf_spectacular/drainage.py | 41 ++++++++------ drf_spectacular/extensions.py | 10 ++-- drf_spectacular/generators.py | 2 +- drf_spectacular/openapi.py | 71 ++++++++++++++--------- drf_spectacular/plumbing.py | 104 ++++++++++++++++++---------------- drf_spectacular/utils.py | 4 +- requirements/docs.txt | 1 + tox.ini | 17 ++++++ 8 files changed, 147 insertions(+), 103 deletions(-) diff --git a/drf_spectacular/drainage.py b/drf_spectacular/drainage.py index f064120f..4c14b049 100644 --- a/drf_spectacular/drainage.py +++ b/drf_spectacular/drainage.py @@ -2,16 +2,21 @@ import functools import sys from collections import defaultdict -from typing import DefaultDict +from typing import Any, Callable, DefaultDict, List, TypeVar if sys.version_info >= (3, 8): from typing import ( # type: ignore[attr-defined] # noqa: F401 Final, Literal, TypedDict, _TypedDictMeta, ) else: - from typing_extensions import ( # type: ignore[attr-defined] # noqa: F401 - Final, Literal, TypedDict, _TypedDictMeta, - ) + from typing_extensions import Final, Literal, TypedDict, _TypedDictMeta # noqa: F401 + +if sys.version_info >= (3, 10): + from typing import TypeGuard # noqa: F401 +else: + from typing_extensions import TypeGuard # noqa: F401 + +F = TypeVar('F', bound=Callable[..., Any]) class GeneratorStats: @@ -35,11 +40,11 @@ def silence(self): finally: self.silent = tmp - def reset(self): + def reset(self) -> None: self._warn_cache.clear() self._error_cache.clear() - def emit(self, msg, severity): + def emit(self, msg: str, severity: str) -> None: assert severity in ['warning', 'error'] msg = _get_current_trace() + str(msg) cache = self._warn_cache if severity == 'warning' else self._error_cache @@ -47,7 +52,7 @@ def emit(self, msg, severity): print(f'{severity.capitalize()} #{len(cache)}: {msg}', file=sys.stderr) cache[msg] += 1 - def emit_summary(self): + def emit_summary(self) -> None: if not self.silent and (self._warn_cache or self._error_cache): print( f'\nSchema generation summary:\n' @@ -60,7 +65,7 @@ def emit_summary(self): GENERATOR_STATS = GeneratorStats() -def warn(msg, delayed=None): +def warn(msg: str, delayed: Any = None): if delayed: warnings = get_override(delayed, 'warnings', []) warnings.append(msg) @@ -69,7 +74,7 @@ def warn(msg, delayed=None): GENERATOR_STATS.emit(msg, 'warning') -def error(msg, delayed=None): +def error(msg: str, delayed: Any = None): if delayed: errors = get_override(delayed, 'errors', []) errors.append(msg) @@ -78,7 +83,7 @@ def error(msg, delayed=None): GENERATOR_STATS.emit(msg, 'error') -def reset_generator_stats(): +def reset_generator_stats() -> None: GENERATOR_STATS.reset() @@ -86,7 +91,7 @@ def reset_generator_stats(): @contextlib.contextmanager -def add_trace_message(trace_message): +def add_trace_message(trace_message: str): """ Adds a message to be used as a prefix when emitting warnings and errors. """ @@ -95,11 +100,11 @@ def add_trace_message(trace_message): _TRACES.pop() -def _get_current_trace(): +def _get_current_trace() -> str: return ''.join(f"{trace}: " for trace in _TRACES if trace) -def has_override(obj, prop): +def has_override(obj: Any, prop: str) -> bool: if isinstance(obj, functools.partial): obj = obj.func if not hasattr(obj, '_spectacular_annotation'): @@ -109,7 +114,7 @@ def has_override(obj, prop): return True -def get_override(obj, prop, default=None): +def get_override(obj: Any, prop: str, default: Any = None) -> Any: if isinstance(obj, functools.partial): obj = obj.func if not has_override(obj, prop): @@ -117,7 +122,7 @@ def get_override(obj, prop, default=None): return obj._spectacular_annotation[prop] -def set_override(obj, prop, value): +def set_override(obj: Any, prop: str, value: Any) -> Any: if not hasattr(obj, '_spectacular_annotation'): obj._spectacular_annotation = {} elif '_spectacular_annotation' not in obj.__dict__: @@ -126,7 +131,7 @@ def set_override(obj, prop, value): return obj -def get_view_method_names(view, schema=None): +def get_view_method_names(view, schema=None) -> List[str]: schema = schema or view.schema return [ item for item in dir(view) if callable(getattr(view, item)) and ( @@ -164,6 +169,6 @@ def wrapped_method(self, request, *args, **kwargs): return wrapped_method -def cache(user_function): +def cache(user_function: F) -> F: """ simple polyfill for python < 3.9 """ - return functools.lru_cache(maxsize=None)(user_function) + return functools.lru_cache(maxsize=None)(user_function) # type: ignore diff --git a/drf_spectacular/extensions.py b/drf_spectacular/extensions.py index 8e4ed150..25ea6425 100644 --- a/drf_spectacular/extensions.py +++ b/drf_spectacular/extensions.py @@ -25,7 +25,7 @@ class OpenApiAuthenticationExtension(OpenApiGeneratorExtension['OpenApiAuthentic ``get_security_definition()`` is expected to return a valid `OpenAPI security scheme object `_ """ - _registry: List['OpenApiAuthenticationExtension'] = [] + _registry: List[Type['OpenApiAuthenticationExtension']] = [] name: Union[str, List[str]] @@ -53,7 +53,7 @@ class OpenApiSerializerExtension(OpenApiGeneratorExtension['OpenApiSerializerExt ``map_serializer()`` is expected to return a valid `OpenAPI schema object `_. """ - _registry: List['OpenApiSerializerExtension'] = [] + _registry: List[Type['OpenApiSerializerExtension']] = [] def get_name(self) -> Optional[str]: """ return str for overriding default name extraction """ @@ -76,7 +76,7 @@ class OpenApiSerializerFieldExtension(OpenApiGeneratorExtension['OpenApiSerializ ``map_serializer_field()`` is expected to return a valid `OpenAPI schema object `_. """ - _registry: List['OpenApiSerializerFieldExtension'] = [] + _registry: List[Type['OpenApiSerializerFieldExtension']] = [] def get_name(self) -> Optional[str]: """ return str for breaking out field schema into separate named component """ @@ -96,7 +96,7 @@ class OpenApiViewExtension(OpenApiGeneratorExtension['OpenApiViewExtension']): ``ViewSet`` et al.). The discovered original view instance can be accessed with ``self.target`` and be subclassed if desired. """ - _registry: List['OpenApiViewExtension'] = [] + _registry: List[Type['OpenApiViewExtension']] = [] @classmethod def _load_class(cls): @@ -123,7 +123,7 @@ class OpenApiFilterExtension(OpenApiGeneratorExtension['OpenApiFilterExtension'] Using ``drf_spectacular.plumbing.build_parameter_type`` is recommended to generate the appropriate raw dict objects. """ - _registry: List['OpenApiFilterExtension'] = [] + _registry: List[Type['OpenApiFilterExtension']] = [] @abstractmethod def get_schema_operation_parameters(self, auto_schema: 'AutoSchema', *args, **kwargs) -> List[dict]: diff --git a/drf_spectacular/generators.py b/drf_spectacular/generators.py index 83812dc7..ce84d9bb 100644 --- a/drf_spectacular/generators.py +++ b/drf_spectacular/generators.py @@ -3,7 +3,7 @@ from django.urls import URLPattern, URLResolver from rest_framework import views, viewsets -from rest_framework.schemas.generators import BaseSchemaGenerator # type: ignore +from rest_framework.schemas.generators import BaseSchemaGenerator from rest_framework.schemas.generators import EndpointEnumerator as BaseEndpointEnumerator from rest_framework.settings import api_settings diff --git a/drf_spectacular/openapi.py b/drf_spectacular/openapi.py index 759a9d15..c863ccf9 100644 --- a/drf_spectacular/openapi.py +++ b/drf_spectacular/openapi.py @@ -1,7 +1,7 @@ import copy import re -import typing from collections import defaultdict +from typing import List, Optional, Union import uritemplate from django.core import exceptions as django_exceptions @@ -13,7 +13,7 @@ from rest_framework.generics import CreateAPIView, GenericAPIView, ListCreateAPIView from rest_framework.mixins import ListModelMixin from rest_framework.schemas.inspectors import ViewInspector -from rest_framework.schemas.utils import get_pk_description # type: ignore +from rest_framework.schemas.utils import get_pk_description from rest_framework.settings import api_settings from rest_framework.utils.model_meta import get_field_info from rest_framework.views import APIView @@ -38,7 +38,9 @@ ) from drf_spectacular.settings import spectacular_settings from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import OpenApiCallback, OpenApiParameter, OpenApiResponse +from drf_spectacular.utils import ( + Direction, OpenApiCallback, OpenApiParameter, OpenApiResponse, _SchemaType, _SerializerType, +) class AutoSchema(ViewInspector): @@ -50,14 +52,21 @@ class AutoSchema(ViewInspector): 'delete': 'destroy', } - def get_operation(self, path, path_regex, path_prefix, method, registry: ComponentRegistry): + def get_operation( + self, + path: str, + path_regex: str, + path_prefix: str, + method: str, + registry: ComponentRegistry + ) -> _SchemaType: self.registry = registry self.path = path self.path_regex = path_regex self.path_prefix = path_prefix self.method = method.upper() - operation = {'operationId': self.get_operation_id()} + operation: _SchemaType = {'operationId': self.get_operation_id()} description = self.get_description() if description: @@ -103,7 +112,7 @@ def get_operation(self, path, path_regex, path_prefix, method, registry: Compone return operation - def _is_list_view(self, serializer=None): + def _is_list_view(self, serializer=None) -> bool: """ partially heuristic approach to determine if a view yields an object or a list of objects. used for operationId naming, array building and pagination. @@ -136,7 +145,7 @@ def _is_list_view(self, serializer=None): return False - def _is_create_operation(self): + def _is_create_operation(self) -> bool: if self.method != 'POST': return False if getattr(self.view, 'action', None) == 'create': @@ -145,7 +154,7 @@ def _is_create_operation(self): return True return False - def get_override_parameters(self): + def get_override_parameters(self) -> List[Union[OpenApiParameter, _SerializerType]]: """ override this for custom behaviour """ return [] @@ -212,13 +221,13 @@ def _process_override_parameters(self): warn(f'could not resolve parameter annotation {parameter}. Skipping.') return result - def _get_format_parameters(self): + def _get_format_parameters(self) -> List[dict]: parameters = [] formats = self.map_renderers('format') if api_settings.URL_FORMAT_OVERRIDE and len(formats) > 1: parameters.append(build_parameter_type( name=api_settings.URL_FORMAT_OVERRIDE, - schema=build_basic_type(OpenApiTypes.STR), + schema=build_basic_type(OpenApiTypes.STR), # type: ignore location=OpenApiParameter.QUERY, enum=formats )) @@ -266,14 +275,14 @@ def dict_helper(parameters): else: return list(parameters.values()) - def get_description(self): + def get_description(self) -> str: # type: ignore """ override this for custom behaviour """ action_or_method = getattr(self.view, getattr(self.view, 'action', self.method.lower()), None) view_doc = get_doc(self.view.__class__) action_doc = get_doc(action_or_method) return action_doc or view_doc - def get_summary(self): + def get_summary(self) -> Optional[str]: """ override this for custom behaviour """ return None @@ -339,21 +348,21 @@ def get_auth(self): auths.append({}) return auths - def get_request_serializer(self) -> typing.Any: + def get_request_serializer(self) -> Optional[_SerializerType]: """ override this for custom behaviour """ return self._get_serializer() - def get_response_serializers(self) -> typing.Any: + def get_response_serializers(self) -> Optional[_SerializerType]: """ override this for custom behaviour """ return self._get_serializer() - def get_tags(self) -> typing.List[str]: + def get_tags(self) -> List[str]: """ override this for custom behaviour """ tokenized_path = self._tokenize_path() # use first non-parameter path part as tag return tokenized_path[:1] - def get_extensions(self) -> typing.Dict[str, typing.Any]: + def get_extensions(self) -> _SchemaType: return {} def _get_callbacks(self): @@ -414,11 +423,11 @@ def _get_callbacks(self): return result - def get_callbacks(self) -> typing.List[OpenApiCallback]: + def get_callbacks(self) -> List[OpenApiCallback]: """ override this for custom behaviour """ return [] - def get_operation_id(self): + def get_operation_id(self) -> str: """ override this for custom behaviour """ tokenized_path = self._tokenize_path() # replace dashes as they can be problematic later in code generation @@ -437,11 +446,11 @@ def get_operation_id(self): return '_'.join(tokenized_path + [action]) - def is_deprecated(self): + def is_deprecated(self) -> bool: """ override this for custom behaviour """ return False - def _tokenize_path(self): + def _tokenize_path(self) -> List[str]: # remove path prefix path = re.sub( pattern=self.path_prefix, @@ -452,8 +461,8 @@ def _tokenize_path(self): # remove path variables path = re.sub(pattern=r'\{[\w\-]+\}', repl='', string=path) # cleanup and tokenize remaining parts. - path = path.rstrip('/').lstrip('/').split('/') - return [t for t in path if t] + tokenized_path = path.rstrip('/').lstrip('/').split('/') + return [t for t in tokenized_path if t] def _resolve_path_parameters(self, variables): model = get_view_model(self.view, emit_warnings=False) @@ -1238,7 +1247,7 @@ def _get_request_for_media_type(self, serializer): request_body_required = False return schema, request_body_required - def _get_response_bodies(self): + def _get_response_bodies(self) -> _SchemaType: response_serializers = self.get_response_serializers() if ( @@ -1272,10 +1281,10 @@ def _get_response_bodies(self): f'Defaulting to generic free-form object.' ) schema = build_basic_type(OpenApiTypes.OBJECT) - schema['description'] = _('Unspecified response body') + schema['description'] = _('Unspecified response body') # type: ignore return {'200': self._get_response_for_code(schema, '200')} - def _unwrap_list_serializer(self, serializer, direction) -> typing.Optional[dict]: + def _unwrap_list_serializer(self, serializer, direction: Direction) -> Optional[_SchemaType]: if is_field(serializer): return self._map_serializer_field(serializer, direction) elif is_basic_serializer(serializer): @@ -1395,7 +1404,11 @@ def _get_response_headers_for_code(self, status_code) -> dict: elif is_serializer(parameter.type): schema = self.resolve_serializer(parameter.type, 'response').ref else: - schema = parameter.type + schema = parameter.type # type: ignore + + if not schema: + warn(f'response parameter {parameter.name} requires non-empty schema') + continue if parameter.location not in [OpenApiParameter.HEADER, OpenApiParameter.COOKIE]: warn(f'incompatible location type ignored for response parameter {parameter.name}') @@ -1422,7 +1435,7 @@ def _get_response_headers_for_code(self, status_code) -> dict: return result - def _get_serializer_name(self, serializer, direction): + def _get_serializer_name(self, serializer, direction: Direction) -> str: serializer_extension = OpenApiSerializerExtension.get_match(serializer) if serializer_extension and serializer_extension.get_name(): # library override mechanisms @@ -1439,6 +1452,8 @@ def _get_serializer_name(self, serializer, direction): else: name = serializer.__class__.__name__ + assert name + if name.endswith('Serializer'): name = name[:-10] @@ -1457,7 +1472,7 @@ def _get_serializer_name(self, serializer, direction): return name - def resolve_serializer(self, serializer, direction) -> ResolvedComponent: + def resolve_serializer(self, serializer: _SerializerType, direction: Direction) -> ResolvedComponent: assert_basic_serializer(serializer) serializer = force_instance(serializer) diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index 3e263fb6..f49d7131 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -12,7 +12,7 @@ from collections import OrderedDict, defaultdict from decimal import Decimal from enum import Enum -from typing import Any, DefaultDict, Generic, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, DefaultDict, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union import inflection import uritemplate @@ -39,12 +39,16 @@ from rest_framework.utils.mediatypes import _MediaType from uritemplate import URITemplate -from drf_spectacular.drainage import Literal, _TypedDictMeta, cache, error, warn +from drf_spectacular.drainage import Literal, TypeGuard, _TypedDictMeta, cache, error, warn from drf_spectacular.settings import spectacular_settings from drf_spectacular.types import ( DJANGO_PATH_CONVERTER_MAPPING, OPENAPI_TYPE_MAPPING, PYTHON_TYPE_MAPPING, OpenApiTypes, + _KnownPythonTypes, +) +from drf_spectacular.utils import ( + OpenApiExample, OpenApiParameter, _FieldType, _ListSerializerType, _ParameterLocationType, + _SchemaType, _SerializerType, ) -from drf_spectacular.utils import OpenApiParameter try: from django.db.models.enums import Choices # only available in Django>3 @@ -54,14 +58,14 @@ class Choices: # type: ignore # types.UnionType was added in Python 3.10 for new PEP 604 pipe union syntax if hasattr(types, 'UnionType'): - UNION_TYPES: Tuple[Any, ...] = (typing.Union, types.UnionType) # type: ignore + UNION_TYPES: Tuple[Any, ...] = (Union, types.UnionType) # type: ignore else: - UNION_TYPES = (typing.Union,) + UNION_TYPES = (Union,) if sys.version_info >= (3, 8): - CACHED_PROPERTY_FUNCS = (functools.cached_property, cached_property) # type: ignore + CACHED_PROPERTY_FUNCS = (functools.cached_property, cached_property) else: - CACHED_PROPERTY_FUNCS = (cached_property,) # type: ignore + CACHED_PROPERTY_FUNCS = (cached_property,) T = TypeVar('T') @@ -83,7 +87,7 @@ def force_instance(serializer_or_field): return serializer_or_field -def is_serializer(obj) -> bool: +def is_serializer(obj) -> TypeGuard[_SerializerType]: from drf_spectacular.serializers import OpenApiSerializerExtension return ( isinstance(force_instance(obj), serializers.BaseSerializer) @@ -91,7 +95,7 @@ def is_serializer(obj) -> bool: ) -def is_list_serializer(obj) -> bool: +def is_list_serializer(obj) -> TypeGuard[_ListSerializerType]: return isinstance(force_instance(obj), serializers.ListSerializer) @@ -107,17 +111,17 @@ def is_list_serializer_customized(obj) -> bool: ) -def is_basic_serializer(obj) -> bool: +def is_basic_serializer(obj) -> TypeGuard[_SerializerType]: return is_serializer(obj) and not is_list_serializer(obj) -def is_field(obj): +def is_field(obj) -> TypeGuard[_FieldType]: # make sure obj is a serializer field and nothing else. # guard against serializers because BaseSerializer(Field) return isinstance(force_instance(obj), fields.Field) and not is_serializer(obj) -def is_basic_type(obj, allow_none=True): +def is_basic_type(obj, allow_none=True) -> TypeGuard[_KnownPythonTypes]: if not isinstance(obj, collections.abc.Hashable): return False if not allow_none and (obj is None or obj is OpenApiTypes.NONE): @@ -125,7 +129,7 @@ def is_basic_type(obj, allow_none=True): return obj in get_openapi_type_mapping() or obj in PYTHON_TYPE_MAPPING -def is_patched_serializer(serializer, direction): +def is_patched_serializer(serializer, direction) -> bool: return bool( spectacular_settings.COMPONENT_SPLIT_PATCH and serializer.partial @@ -134,13 +138,13 @@ def is_patched_serializer(serializer, direction): ) -def is_trivial_string_variation(a: str, b: str): +def is_trivial_string_variation(a: str, b: str) -> bool: a = (a or '').strip().lower().replace(' ', '_').replace('-', '_') b = (b or '').strip().lower().replace(' ', '_').replace('-', '_') return a == b -def assert_basic_serializer(serializer): +def assert_basic_serializer(serializer) -> None: assert is_basic_serializer(serializer), ( f'internal assumption violated because we expected a basic serializer here and ' f'instead got a "{serializer}". This may be the result of another app doing ' @@ -188,9 +192,9 @@ def get_view_model(view, emit_warnings=True): ) -def get_doc(obj): +def get_doc(obj) -> str: """ get doc string with fallback on obj's base classes (ignoring DRF documentation). """ - def post_cleanup(doc: str): + def post_cleanup(doc: str) -> str: # also clean up trailing whitespace for each line return '\n'.join(line.rstrip() for line in doc.rstrip().split('\n')) @@ -212,7 +216,7 @@ def safe_index(lst, item): return '' -def get_type_hints(obj): +def get_type_hints(obj) -> Dict[str, Any]: """ unpack wrapped partial object and use actual func object """ if isinstance(obj, functools.partial): obj = obj.func @@ -227,7 +231,7 @@ def get_openapi_type_mapping(): } -def build_generic_type(): +def build_generic_type() -> _SchemaType: if spectacular_settings.GENERIC_ADDITIONAL_PROPERTIES is None: return {'type': 'object'} elif spectacular_settings.GENERIC_ADDITIONAL_PROPERTIES == 'bool': @@ -236,7 +240,7 @@ def build_generic_type(): return {'type': 'object', 'additionalProperties': {}} -def build_basic_type(obj): +def build_basic_type(obj: Union[_KnownPythonTypes, OpenApiTypes]) -> Optional[_SchemaType]: """ resolve either enum or actual type and yield schema template for modification """ @@ -252,7 +256,7 @@ def build_basic_type(obj): return dict(openapi_type_mapping[OpenApiTypes.STR]) -def build_array_type(schema, min_length=None, max_length=None): +def build_array_type(schema: _SchemaType, min_length=None, max_length=None) -> _SchemaType: schema = {'type': 'array', 'items': schema} if min_length is not None: schema['minLength'] = min_length @@ -262,12 +266,12 @@ def build_array_type(schema, min_length=None, max_length=None): def build_object_type( - properties=None, + properties: Optional[_SchemaType] = None, required=None, - description=None, + description: Optional[str] = None, **kwargs -): - schema = {'type': 'object'} +) -> _SchemaType: + schema: _SchemaType = {'type': 'object'} if description: schema['description'] = description.strip() if properties: @@ -280,14 +284,14 @@ def build_object_type( return schema -def build_media_type_object(schema, examples=None): +def build_media_type_object(schema: _SchemaType, examples=None) -> _SchemaType: media_type_object = {'schema': schema} if examples: media_type_object['examples'] = examples return media_type_object -def build_examples_list(examples): +def build_examples_list(examples: List[OpenApiExample]) -> _SchemaType: schema = {} for example in examples: normalized_name = inflection.camelize(example.name.replace(' ', '_')) @@ -307,9 +311,9 @@ def build_examples_list(examples): def build_parameter_type( - name, - schema, - location, + name: str, + schema: _SchemaType, + location: _ParameterLocationType, required=False, description=None, enum=None, @@ -321,7 +325,7 @@ def build_parameter_type( allow_blank=True, examples=None, extensions=None, -): +) -> _SchemaType: irrelevant_field_meta = ['readOnly', 'writeOnly'] if location == OpenApiParameter.PATH: irrelevant_field_meta += ['nullable', 'default'] @@ -355,11 +359,11 @@ def build_parameter_type( return schema -def build_choice_field(field): +def build_choice_field(field) -> _SchemaType: choices = list(OrderedDict.fromkeys(field.choices)) # preserve order and remove duplicates if all(isinstance(choice, bool) for choice in choices): - type = 'boolean' + type: Optional[str] = 'boolean' elif all(isinstance(choice, int) for choice in choices): type = 'integer' elif all(isinstance(choice, (int, float, Decimal)) for choice in choices): # `number` includes `integer` @@ -375,7 +379,7 @@ def build_choice_field(field): if field.allow_null: choices.append(None) - schema = { + schema: _SchemaType = { # The value of `enum` keyword MUST be an array and SHOULD be unique. # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.20 'enum': choices @@ -390,7 +394,7 @@ def build_choice_field(field): return schema -def build_bearer_security_scheme_object(header_name, token_prefix, bearer_format=None): +def build_bearer_security_scheme_object(header_name, token_prefix, bearer_format=None) -> _SchemaType: """ Either build a bearer scheme or a fallback due to OpenAPI 3.0.3 limitations """ # normalize Django header quirks if header_name.startswith('HTTP_'): @@ -414,7 +418,7 @@ def build_bearer_security_scheme_object(header_name, token_prefix, bearer_format } -def build_root_object(paths, components, version): +def build_root_object(paths, components, version) -> _SchemaType: settings = spectacular_settings if settings.VERSION and version: version = f'{settings.VERSION} ({version})' @@ -448,7 +452,7 @@ def build_root_object(paths, components, version): return root -def safe_ref(schema): +def safe_ref(schema: _SchemaType) -> _SchemaType: """ ensure that $ref has its own context and does not remove potential sibling entries when $ref is substituted. @@ -458,7 +462,7 @@ def safe_ref(schema): return schema -def append_meta(schema, meta): +def append_meta(schema: _SchemaType, meta: _SchemaType) -> _SchemaType: return safe_ref({**schema, **meta}) @@ -606,18 +610,18 @@ def __bool__(self): return bool(self.name and self.type and self.object) @property - def key(self): + def key(self) -> Tuple[str, str]: return self.name, self.type @property - def ref(self) -> dict: + def ref(self) -> _SchemaType: assert self.__bool__() return {'$ref': f'#/components/{self.type}/{self.name}'} class ComponentRegistry: def __init__(self): - self._components = {} + self._components: Dict[Tuple[str, str], ResolvedComponent] = {} def register(self, component: ResolvedComponent): if component in self: @@ -649,7 +653,7 @@ def __contains__(self, component): ) return True - def __getitem__(self, key): + def __getitem__(self, key) -> ResolvedComponent: if isinstance(key, ResolvedComponent): key = key.key return self._components[key] @@ -676,7 +680,7 @@ def build(self, extra_components) -> dict: class OpenApiGeneratorExtension(Generic[T], metaclass=ABCMeta): - _registry: List[T] = [] + _registry: List[Type[T]] = [] target_class: Union[None, str, Type[object]] = None match_subclasses = False priority = 0 @@ -702,7 +706,7 @@ def _load_class(cls): cls.target_class = None @classmethod - def _matches(cls, target) -> bool: + def _matches(cls, target: Any) -> bool: if isinstance(cls.target_class, str): cls._load_class() @@ -721,7 +725,7 @@ def get_match(cls, target) -> Optional[T]: return None -def deep_import_string(string): +def deep_import_string(string: str) -> Any: """ augmented import from string, e.g. MODULE.CLASS/OBJECT.ATTRIBUTE """ try: return import_string(string) @@ -769,7 +773,7 @@ def load_enum_name_overrides(): return overrides -def list_hash(lst): +def list_hash(lst: List[Any]) -> str: return hashlib.sha256(json.dumps(list(lst), sort_keys=True).encode()).hexdigest() @@ -870,7 +874,7 @@ def resolve_regex_path_parameter(path_regex, variable): return None -def is_versioning_supported(versioning_class): +def is_versioning_supported(versioning_class) -> bool: return issubclass(versioning_class, ( versioning.URLPathVersioning, versioning.NamespaceVersioning, @@ -878,7 +882,7 @@ def is_versioning_supported(versioning_class): )) -def operation_matches_version(view, requested_version): +def operation_matches_version(view, requested_version) -> bool: try: version, _ = view.determine_version(view.request, **view.kwargs) except exceptions.NotAcceptable: @@ -944,7 +948,7 @@ def modify_media_types_for_versioning(view, media_types: List[str]) -> List[str] ] -def analyze_named_regex_pattern(path): +def analyze_named_regex_pattern(path: str) -> Dict[str, str]: """ safely extract named groups and their pattern from given regex pattern """ result = {} stack = 0 @@ -1225,7 +1229,7 @@ def resolve_type_hint(hint): raise UnableToProceedError() -def whitelisted(obj: object, classes: List[Type[object]], exact=False): +def whitelisted(obj: object, classes: List[Type[object]], exact=False) -> bool: if not classes: return True if exact: diff --git a/drf_spectacular/utils.py b/drf_spectacular/utils.py index 3ddf6da4..5672e407 100644 --- a/drf_spectacular/utils.py +++ b/drf_spectacular/utils.py @@ -9,7 +9,7 @@ from typing_extensions import Final, Literal from rest_framework.fields import Field, empty -from rest_framework.serializers import Serializer +from rest_framework.serializers import ListSerializer, Serializer from rest_framework.settings import api_settings from drf_spectacular.drainage import ( @@ -17,9 +17,11 @@ ) from drf_spectacular.types import OpenApiTypes, _KnownPythonTypes +_ListSerializerType = Union[ListSerializer, Type[ListSerializer]] _SerializerType = Union[Serializer, Type[Serializer]] _FieldType = Union[Field, Type[Field]] _ParameterLocationType = Literal['query', 'path', 'header', 'cookie'] +_SchemaType = Dict[str, Any] Direction = Literal['request', 'response'] diff --git a/requirements/docs.txt b/requirements/docs.txt index 486e4406..ca42c48f 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,2 +1,3 @@ Sphinx>=4.1.0 sphinx_rtd_theme>=0.5.1 +typing-extensions diff --git a/tox.ini b/tox.ini index 91513eab..395e83f5 100644 --- a/tox.ini +++ b/tox.ini @@ -86,10 +86,27 @@ include_trailing_comma = true [mypy] python_version = 3.8 plugins = mypy_django_plugin.main,mypy_drf_plugin.main +warn_unused_configs = True +warn_redundant_casts = True +warn_unused_ignores = True [mypy.plugins.django-stubs] django_settings_module = "tests.settings" +[mypy-drf_spectacular.*] +strict_equality = True +no_implicit_optional = True +disallow_untyped_decorators = True +disallow_subclassing_any = True +;check_untyped_defs = True +;warn_return_any = True +;no_implicit_reexport = True +;disallow_incomplete_defs = True +;disallow_any_generics = True +;disallow_untyped_calls = True +;disallow_untyped_defs = True + + [mypy-rest_framework.compat.*] ignore_missing_imports = True