diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index 2b03db45..cd6d1b1f 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -81,6 +81,16 @@ class Choices: # type: ignore else: CACHED_PROPERTY_FUNCS = (cached_property,) # type: ignore +if hasattr(typing, "NotRequired"): + NotRequired = typing.NotRequired +else: + from typing_extensions import NotRequired + +if hasattr(typing, "Required"): + Required = typing.Required +else: + from typing_extensions import Required + T = TypeVar('T') @@ -1216,13 +1226,39 @@ def _resolve_typeddict(hint): required = None if hasattr(hint, '__required_keys__'): - required = [h for h in hint.__required_keys__] + required = {h for h in hint.__required_keys__} + + properties = {} + + for k, v in get_type_hints(hint).items(): + origin, args = _get_type_hint_origin(v) + + # Unwrap Required and NotRequired, as get_type_hints() does + # not understand them + if origin == NotRequired or origin == Required: + # If we are on Python3.11 or later, or we are on an earlier + # version of python and are explicitly using typing_extensions.TypedDict, + # then the value of required should already be set correctly + # However, it does present a bit of a foot-gun, so we + # have repeated the logic here as a safeguard in the case + # that a user is on Python version 3.9 or 3.10 is not using + # typing_extensions.TypedDict + if origin == Required: + required.add(k) + else: + required.discard(k) + + if len(args) != 1: + raise UnableToProceedError() + + properties[k] = resolve_type_hint(args[0]) + + else: + properties[k] = resolve_type_hint(v) return build_object_type( - properties={ - k: resolve_type_hint(v) for k, v in get_type_hints(hint).items() - }, - required=required, + properties=properties, + required=None if required is None else list(required), description=get_doc(hint), ) diff --git a/tests/test_plumbing.py b/tests/test_plumbing.py index d2ca7ef5..1f6780fc 100644 --- a/tests/test_plumbing.py +++ b/tests/test_plumbing.py @@ -26,6 +26,16 @@ from drf_spectacular.validation import validate_schema from tests import generate_schema +if hasattr(typing, "NotRequired"): + NotRequired = typing.NotRequired +else: + from typing_extensions import NotRequired + +if hasattr(typing, "Required"): + Required = typing.Required +else: + from typing_extensions import Required + def test_get_list_serializer_preserves_context(): serializer = serializers.Serializer(context={"foo": "bar"}) @@ -240,6 +250,12 @@ class TD4Optional(TypedDict, total=False): class TD4(TD4Optional): """A test description2""" b: bool + + class TD5(TypedDict): + """A test description3""" + a: NotRequired[str] + b: Required[bool] + TYPE_HINT_TEST_PARAMS.append(( TD1, { @@ -277,6 +293,18 @@ class TD4(TD4Optional): 'required': ['b'], }) ) + TYPE_HINT_TEST_PARAMS.append(( + TD5, + { + 'type': 'object', + 'description': 'A test description3', + 'properties': { + 'a': {'type': 'string'}, + 'b': {'type': 'boolean'} + }, + 'required': ['b'], + }) + ) else: TYPE_HINT_TEST_PARAMS.append(( TD1,