Skip to content

Commit

Permalink
improve mypy typing #600
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed May 24, 2022
1 parent 13ea8e2 commit dee6328
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 103 deletions.
41 changes: 23 additions & 18 deletions drf_spectacular/drainage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@
import functools
import sys
from collections import defaultdict
from typing import DefaultDict
from typing import Any, Callable, DefaultDict, List, TypeVar

if sys.version_info >= (3, 8):
from typing import ( # type: ignore[attr-defined] # noqa: F401
Final, Literal, TypedDict, _TypedDictMeta,
)
else:
from typing_extensions import ( # type: ignore[attr-defined] # noqa: F401
Final, Literal, TypedDict, _TypedDictMeta,
)
from typing_extensions import Final, Literal, TypedDict, _TypedDictMeta # noqa: F401

if sys.version_info >= (3, 10):
from typing import TypeGuard # noqa: F401
else:
from typing_extensions import TypeGuard # noqa: F401

F = TypeVar('F', bound=Callable[..., Any])


class GeneratorStats:
Expand All @@ -35,19 +40,19 @@ def silence(self):
finally:
self.silent = tmp

def reset(self):
def reset(self) -> None:
self._warn_cache.clear()
self._error_cache.clear()

def emit(self, msg, severity):
def emit(self, msg: str, severity: str) -> None:
assert severity in ['warning', 'error']
msg = _get_current_trace() + str(msg)
cache = self._warn_cache if severity == 'warning' else self._error_cache
if not self.silent and msg not in cache:
print(f'{severity.capitalize()} #{len(cache)}: {msg}', file=sys.stderr)
cache[msg] += 1

def emit_summary(self):
def emit_summary(self) -> None:
if not self.silent and (self._warn_cache or self._error_cache):
print(
f'\nSchema generation summary:\n'
Expand All @@ -60,7 +65,7 @@ def emit_summary(self):
GENERATOR_STATS = GeneratorStats()


def warn(msg, delayed=None):
def warn(msg: str, delayed: Any = None):
if delayed:
warnings = get_override(delayed, 'warnings', [])
warnings.append(msg)
Expand All @@ -69,7 +74,7 @@ def warn(msg, delayed=None):
GENERATOR_STATS.emit(msg, 'warning')


def error(msg, delayed=None):
def error(msg: str, delayed: Any = None):
if delayed:
errors = get_override(delayed, 'errors', [])
errors.append(msg)
Expand All @@ -78,15 +83,15 @@ def error(msg, delayed=None):
GENERATOR_STATS.emit(msg, 'error')


def reset_generator_stats():
def reset_generator_stats() -> None:
GENERATOR_STATS.reset()


_TRACES = []


@contextlib.contextmanager
def add_trace_message(trace_message):
def add_trace_message(trace_message: str):
"""
Adds a message to be used as a prefix when emitting warnings and errors.
"""
Expand All @@ -95,11 +100,11 @@ def add_trace_message(trace_message):
_TRACES.pop()


def _get_current_trace():
def _get_current_trace() -> str:
return ''.join(f"{trace}: " for trace in _TRACES if trace)


def has_override(obj, prop):
def has_override(obj: Any, prop: str) -> bool:
if isinstance(obj, functools.partial):
obj = obj.func
if not hasattr(obj, '_spectacular_annotation'):
Expand All @@ -109,15 +114,15 @@ def has_override(obj, prop):
return True


def get_override(obj, prop, default=None):
def get_override(obj: Any, prop: str, default: Any = None) -> Any:
if isinstance(obj, functools.partial):
obj = obj.func
if not has_override(obj, prop):
return default
return obj._spectacular_annotation[prop]


def set_override(obj, prop, value):
def set_override(obj: Any, prop: str, value: Any) -> Any:
if not hasattr(obj, '_spectacular_annotation'):
obj._spectacular_annotation = {}
elif '_spectacular_annotation' not in obj.__dict__:
Expand All @@ -126,7 +131,7 @@ def set_override(obj, prop, value):
return obj


def get_view_method_names(view, schema=None):
def get_view_method_names(view, schema=None) -> List[str]:
schema = schema or view.schema
return [
item for item in dir(view) if callable(getattr(view, item)) and (
Expand Down Expand Up @@ -164,6 +169,6 @@ def wrapped_method(self, request, *args, **kwargs):
return wrapped_method


def cache(user_function):
def cache(user_function: F) -> F:
""" simple polyfill for python < 3.9 """
return functools.lru_cache(maxsize=None)(user_function)
return functools.lru_cache(maxsize=None)(user_function) # type: ignore
10 changes: 5 additions & 5 deletions drf_spectacular/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class OpenApiAuthenticationExtension(OpenApiGeneratorExtension['OpenApiAuthentic
``get_security_definition()`` is expected to return a valid `OpenAPI security scheme object
<https://spec.openapis.org/oas/v3.0.3#securitySchemeObject>`_
"""
_registry: List['OpenApiAuthenticationExtension'] = []
_registry: List[Type['OpenApiAuthenticationExtension']] = []

name: Union[str, List[str]]

Expand Down Expand Up @@ -53,7 +53,7 @@ class OpenApiSerializerExtension(OpenApiGeneratorExtension['OpenApiSerializerExt
``map_serializer()`` is expected to return a valid `OpenAPI schema object
<https://spec.openapis.org/oas/v3.0.3#schemaObject>`_.
"""
_registry: List['OpenApiSerializerExtension'] = []
_registry: List[Type['OpenApiSerializerExtension']] = []

def get_name(self) -> Optional[str]:
""" return str for overriding default name extraction """
Expand All @@ -76,7 +76,7 @@ class OpenApiSerializerFieldExtension(OpenApiGeneratorExtension['OpenApiSerializ
``map_serializer_field()`` is expected to return a valid `OpenAPI schema object
<https://spec.openapis.org/oas/v3.0.3#schemaObject>`_.
"""
_registry: List['OpenApiSerializerFieldExtension'] = []
_registry: List[Type['OpenApiSerializerFieldExtension']] = []

def get_name(self) -> Optional[str]:
""" return str for breaking out field schema into separate named component """
Expand All @@ -96,7 +96,7 @@ class OpenApiViewExtension(OpenApiGeneratorExtension['OpenApiViewExtension']):
``ViewSet`` et al.). The discovered original view instance can be accessed with
``self.target`` and be subclassed if desired.
"""
_registry: List['OpenApiViewExtension'] = []
_registry: List[Type['OpenApiViewExtension']] = []

@classmethod
def _load_class(cls):
Expand All @@ -123,7 +123,7 @@ class OpenApiFilterExtension(OpenApiGeneratorExtension['OpenApiFilterExtension']
Using ``drf_spectacular.plumbing.build_parameter_type`` is recommended to generate
the appropriate raw dict objects.
"""
_registry: List['OpenApiFilterExtension'] = []
_registry: List[Type['OpenApiFilterExtension']] = []

@abstractmethod
def get_schema_operation_parameters(self, auto_schema: 'AutoSchema', *args, **kwargs) -> List[dict]:
Expand Down
2 changes: 1 addition & 1 deletion drf_spectacular/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from django.urls import URLPattern, URLResolver
from rest_framework import views, viewsets
from rest_framework.schemas.generators import BaseSchemaGenerator # type: ignore
from rest_framework.schemas.generators import BaseSchemaGenerator
from rest_framework.schemas.generators import EndpointEnumerator as BaseEndpointEnumerator
from rest_framework.settings import api_settings

Expand Down
Loading

0 comments on commit dee6328

Please sign in to comment.