Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve mypy typing #600 #620

Merged
merged 1 commit into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion drf_spectacular/contrib/rest_polymorphic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from drf_spectacular.drainage import warn
from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.plumbing import (
ResolvedComponent, build_basic_type, build_object_type, is_patched_serializer, warn,
ResolvedComponent, build_basic_type, build_object_type, is_patched_serializer,
)
from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes
Expand Down
46 changes: 30 additions & 16 deletions drf_spectacular/drainage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,21 @@
import inspect
import sys
from collections import defaultdict
from typing import DefaultDict, List, Optional, Tuple
from typing import Any, Callable, DefaultDict, List, Optional, Tuple, TypeVar

if sys.version_info >= (3, 8):
from typing import ( # type: ignore[attr-defined] # noqa: F401
Final, Literal, TypedDict, _TypedDictMeta,
)
else:
from typing_extensions import Final, Literal, TypedDict, _TypedDictMeta # noqa: F401

if sys.version_info >= (3, 10):
from typing import TypeGuard # noqa: F401
else:
from typing_extensions import TypeGuard # noqa: F401

F = TypeVar('F', bound=Callable[..., Any])


class GeneratorStats:
Expand Down Expand Up @@ -37,20 +51,20 @@ def silence(self):
finally:
self.silent = tmp

def reset(self):
def reset(self) -> None:
self._warn_cache.clear()
self._error_cache.clear()

def enable_color(self):
def enable_color(self) -> None:
self._blue = '\033[0;34m'
self._red = '\033[0;31m'
self._yellow = '\033[0;33m'
self._clear = '\033[0m'

def enable_trace_lineno(self):
def enable_trace_lineno(self) -> None:
self._trace_lineno = True

def _get_current_trace(self):
def _get_current_trace(self) -> Tuple[Optional[str], str]:
source_locations = [t for t in self._traces if t[0]]
if source_locations:
sourcefile, lineno, _ = source_locations[-1]
Expand All @@ -60,7 +74,7 @@ def _get_current_trace(self):
breadcrumbs = ' > '.join(t[2] for t in self._traces)
return source_location, breadcrumbs

def emit(self, msg, severity):
def emit(self, msg: str, severity: str) -> None:
assert severity in ['warning', 'error']
cache = self._warn_cache if severity == 'warning' else self._error_cache

Expand All @@ -75,7 +89,7 @@ def emit(self, msg, severity):
print(msg, file=sys.stderr)
cache[msg] += 1

def emit_summary(self):
def emit_summary(self) -> None:
if not self.silent and (self._warn_cache or self._error_cache):
print(
f'\nSchema generation summary:\n'
Expand All @@ -88,7 +102,7 @@ def emit_summary(self):
GENERATOR_STATS = GeneratorStats()


def warn(msg, delayed=None):
def warn(msg: str, delayed: Any = None) -> None:
if delayed:
warnings = get_override(delayed, 'warnings', [])
warnings.append(msg)
Expand All @@ -97,7 +111,7 @@ def warn(msg, delayed=None):
GENERATOR_STATS.emit(msg, 'warning')


def error(msg, delayed=None):
def error(msg: str, delayed: Any = None) -> None:
if delayed:
errors = get_override(delayed, 'errors', [])
errors.append(msg)
Expand All @@ -106,7 +120,7 @@ def error(msg, delayed=None):
GENERATOR_STATS.emit(msg, 'error')


def reset_generator_stats():
def reset_generator_stats() -> None:
GENERATOR_STATS.reset()


Expand Down Expand Up @@ -136,7 +150,7 @@ def _get_source_location(obj):
return sourcefile, lineno


def has_override(obj, prop):
def has_override(obj: Any, prop: str) -> bool:
if isinstance(obj, functools.partial):
obj = obj.func
if not hasattr(obj, '_spectacular_annotation'):
Expand All @@ -146,15 +160,15 @@ def has_override(obj, prop):
return True


def get_override(obj, prop, default=None):
def get_override(obj: Any, prop: str, default: Any = None) -> Any:
if isinstance(obj, functools.partial):
obj = obj.func
if not has_override(obj, prop):
return default
return obj._spectacular_annotation[prop]


def set_override(obj, prop, value):
def set_override(obj: Any, prop: str, value: Any) -> Any:
if not hasattr(obj, '_spectacular_annotation'):
obj._spectacular_annotation = {}
elif '_spectacular_annotation' not in obj.__dict__:
Expand All @@ -163,7 +177,7 @@ def set_override(obj, prop, value):
return obj


def get_view_method_names(view, schema=None):
def get_view_method_names(view, schema=None) -> List[str]:
schema = schema or view.schema
return [
item for item in dir(view) if callable(getattr(view, item)) and (
Expand Down Expand Up @@ -201,6 +215,6 @@ def wrapped_method(self, request, *args, **kwargs):
return wrapped_method


def cache(user_function):
def cache(user_function: F) -> F:
""" simple polyfill for python < 3.9 """
return functools.lru_cache(maxsize=None)(user_function)
return functools.lru_cache(maxsize=None)(user_function) # type: ignore
21 changes: 12 additions & 9 deletions drf_spectacular/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from drf_spectacular.openapi import AutoSchema


_SchemaType = Dict[str, Any]


class OpenApiAuthenticationExtension(OpenApiGeneratorExtension['OpenApiAuthenticationExtension']):
"""
Extension for specifying authentication schemes.
Expand All @@ -29,7 +32,7 @@ class OpenApiAuthenticationExtension(OpenApiGeneratorExtension['OpenApiAuthentic
``get_security_definition()`` is expected to return a valid `OpenAPI security scheme object
<https://spec.openapis.org/oas/v3.0.3#securitySchemeObject>`_
"""
_registry: List['OpenApiAuthenticationExtension'] = []
_registry: List[Type['OpenApiAuthenticationExtension']] = []

name: Union[str, List[str]]

Expand All @@ -43,7 +46,7 @@ def get_security_requirement(
return {name: [] for name in self.name}

@abstractmethod
def get_security_definition(self, auto_schema: 'AutoSchema') -> Union[dict, List[dict]]:
def get_security_definition(self, auto_schema: 'AutoSchema') -> Union[_SchemaType, List[_SchemaType]]:
pass # pragma: no cover


Expand All @@ -59,13 +62,13 @@ class OpenApiSerializerExtension(OpenApiGeneratorExtension['OpenApiSerializerExt
``map_serializer()`` is expected to return a valid `OpenAPI schema object
<https://spec.openapis.org/oas/v3.0.3#schemaObject>`_.
"""
_registry: List['OpenApiSerializerExtension'] = []
_registry: List[Type['OpenApiSerializerExtension']] = []

def get_name(self, auto_schema: 'AutoSchema', direction: Direction) -> Optional[str]:
""" return str for overriding default name extraction """
return None

def map_serializer(self, auto_schema: 'AutoSchema', direction: Direction):
def map_serializer(self, auto_schema: 'AutoSchema', direction: Direction) -> _SchemaType:
""" override for customized serializer mapping """
return auto_schema._map_serializer(self.target_class, direction, bypass_extensions=True)

Expand All @@ -82,14 +85,14 @@ class OpenApiSerializerFieldExtension(OpenApiGeneratorExtension['OpenApiSerializ
``map_serializer_field()`` is expected to return a valid `OpenAPI schema object
<https://spec.openapis.org/oas/v3.0.3#schemaObject>`_.
"""
_registry: List['OpenApiSerializerFieldExtension'] = []
_registry: List[Type['OpenApiSerializerFieldExtension']] = []

def get_name(self) -> Optional[str]:
""" return str for breaking out field schema into separate named component """
return None

@abstractmethod
def map_serializer_field(self, auto_schema: 'AutoSchema', direction: Direction):
def map_serializer_field(self, auto_schema: 'AutoSchema', direction: Direction) -> _SchemaType:
""" override for customized serializer field mapping """
pass # pragma: no cover

Expand All @@ -102,7 +105,7 @@ class OpenApiViewExtension(OpenApiGeneratorExtension['OpenApiViewExtension']):
``ViewSet`` et al.). The discovered original view instance can be accessed with
``self.target`` and be subclassed if desired.
"""
_registry: List['OpenApiViewExtension'] = []
_registry: List[Type['OpenApiViewExtension']] = []

@classmethod
def _load_class(cls):
Expand All @@ -129,8 +132,8 @@ class OpenApiFilterExtension(OpenApiGeneratorExtension['OpenApiFilterExtension']
Using ``drf_spectacular.plumbing.build_parameter_type`` is recommended to generate
the appropriate raw dict objects.
"""
_registry: List['OpenApiFilterExtension'] = []
_registry: List[Type['OpenApiFilterExtension']] = []

@abstractmethod
def get_schema_operation_parameters(self, auto_schema: 'AutoSchema', *args, **kwargs) -> List[dict]:
def get_schema_operation_parameters(self, auto_schema: 'AutoSchema', *args, **kwargs) -> List[_SchemaType]:
pass # pragma: no cover
12 changes: 7 additions & 5 deletions drf_spectacular/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@

from django.urls import URLPattern, URLResolver
from rest_framework import views, viewsets
from rest_framework.schemas.generators import BaseSchemaGenerator # type: ignore
from rest_framework.schemas.generators import BaseSchemaGenerator
from rest_framework.schemas.generators import EndpointEnumerator as BaseEndpointEnumerator
from rest_framework.settings import api_settings

from drf_spectacular.drainage import add_trace_message, get_override, reset_generator_stats
from drf_spectacular.drainage import (
add_trace_message, error, get_override, reset_generator_stats, warn,
)
from drf_spectacular.extensions import OpenApiViewExtension
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.plumbing import (
ComponentRegistry, alpha_operation_sorter, build_root_object, camelize_operation, error,
get_class, is_versioning_supported, modify_for_versioning, normalize_result_object,
operation_matches_version, sanitize_result_object, warn,
ComponentRegistry, alpha_operation_sorter, build_root_object, camelize_operation, get_class,
is_versioning_supported, modify_for_versioning, normalize_result_object,
operation_matches_version, sanitize_result_object,
)
from drf_spectacular.settings import spectacular_settings

Expand Down
3 changes: 2 additions & 1 deletion drf_spectacular/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from inflection import camelize
from rest_framework.settings import api_settings

from drf_spectacular.drainage import warn
from drf_spectacular.plumbing import (
ResolvedComponent, list_hash, load_enum_name_overrides, safe_ref, warn,
ResolvedComponent, list_hash, load_enum_name_overrides, safe_ref,
)
from drf_spectacular.settings import spectacular_settings

Expand Down
Loading