diff --git a/ninja/signature/details.py b/ninja/signature/details.py index 8f026b067..57ffb6e1e 100644 --- a/ninja/signature/details.py +++ b/ninja/signature/details.py @@ -187,8 +187,13 @@ def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]] arg_names: Any = {} for arg in args: if is_pydantic_model(arg.annotation): - for name, path in self._model_flatten_map(arg.annotation, arg.alias): - if name in flatten_map: + for name, path, is_union_descendant in self._model_flatten_map( + arg.annotation, arg.alias + ): + model = arg.annotation + if get_origin(model) is Annotated: + model = get_args(model)[0] + if not is_union_descendant and name in flatten_map: raise ConfigError( f"Duplicated name: '{name}' in params: '{arg_names[name]}' & '{arg.name}'" ) @@ -205,15 +210,26 @@ def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]] return flatten_map - def _model_flatten_map(self, model: TModel, prefix: str) -> Generator: + def _model_flatten_map( + self, model: TModel, prefix: str, is_union_descendant: bool = False + ) -> Generator[Tuple[str, str, bool], None, None]: field: FieldInfo - for attr, field in model.model_fields.items(): - field_name = field.alias or attr - name = f"{prefix}{self.FLATTEN_PATH_SEP}{field_name}" - if is_pydantic_model(field.annotation): - yield from self._model_flatten_map(field.annotation, name) # type: ignore - else: - yield field_name, name + if get_origin(model) is Annotated: + model = get_args(model)[0] + if get_origin(model) in UNION_TYPES: + # If the model is a union type, process each type in the union + for arg in get_args(model): + yield from self._model_flatten_map(arg, prefix, True) + else: + for attr, field in model.model_fields.items(): + field_name = field.alias or attr + name = f"{prefix}{self.FLATTEN_PATH_SEP}{field_name}" + if is_pydantic_model(field.annotation): + yield from self._model_flatten_map( + field.annotation, name, is_union_descendant + ) # type: ignore + else: + yield field_name, name, is_union_descendant def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam: # _EMPTY = self.signature.empty @@ -336,24 +352,40 @@ def detect_collection_fields( for path in (p for p in flatten_map.values() if len(p) > 1): annotation_or_field: Any = args_d[path[0]].annotation for attr in path[1:]: - if hasattr(annotation_or_field, "annotation"): - annotation_or_field = annotation_or_field.annotation - annotation_or_field = next( - ( - a - for a in annotation_or_field.model_fields.values() - if a.alias == attr - ), - annotation_or_field.model_fields.get(attr), - ) # pragma: no cover + if get_origin(annotation_or_field) is Annotated: + annotation_or_field = get_args(annotation_or_field)[0] + + # check union types + if get_origin(annotation_or_field) in UNION_TYPES: + for arg in get_args(annotation_or_field): + annotation_or_field = _detect_collection_fields( + arg, attr, path, result + ) + else: + annotation_or_field = _detect_collection_fields( + annotation_or_field, attr, path, result + ) + return result - annotation_or_field = getattr( - annotation_or_field, "outer_type_", annotation_or_field - ) - # if hasattr(annotation_or_field, "annotation"): - annotation_or_field = annotation_or_field.annotation +def _detect_collection_fields( + annotation_or_field: Any, + attr: str, + path: Tuple[str, ...], + result: List[Any], +) -> Any: + annotation_or_field = next( + (a for a in annotation_or_field.model_fields.values() if a.alias == attr), + annotation_or_field.model_fields.get(attr), + ) # pragma: no cover - if is_collection_type(annotation_or_field): - result.append(path[-1]) - return result + annotation_or_field = getattr( + annotation_or_field, "outer_type_", annotation_or_field + ) + + # if hasattr(annotation_or_field, "annotation"): + annotation_or_field = annotation_or_field.annotation + + if is_collection_type(annotation_or_field): + return result.append(path[-1]) + return annotation_or_field diff --git a/tests/test_discriminator.py b/tests/test_discriminator.py index 7e9ae0de2..1c9a3764b 100644 --- a/tests/test_discriminator.py +++ b/tests/test_discriminator.py @@ -3,7 +3,7 @@ from pydantic import Field from typing_extensions import Annotated, Literal -from ninja import NinjaAPI, Schema +from ninja import NinjaAPI, Query, Schema from ninja.testing import TestClient @@ -37,21 +37,43 @@ def create_example_regular(request, payload: RegularUnion): return {"data": payload.model_dump(), "type": payload.__class__.__name__} +@api.get("/descr-union") +def get_example(request, payload: UnionDiscriminator = Query(...)): + return {} + + +@api.get("/regular-union") +def get_example_regular(request, payload: RegularUnion = Query(...)): + return {} + + client = TestClient(api) def test_schema(): schema = api.get_openapi_schema() - detail1 = schema["paths"]["/api/descr-union"]["post"]["requestBody"]["content"][ - "application/json" - ]["schema"] - detail2 = schema["paths"]["/api/regular-union"]["post"]["requestBody"]["content"][ - "application/json" - ]["schema"] + post_detail1 = schema["paths"]["/api/descr-union"]["post"]["requestBody"][ + "content" + ]["application/json"]["schema"] + post_detail2 = schema["paths"]["/api/regular-union"]["post"]["requestBody"][ + "content" + ]["application/json"]["schema"] + get_detail1 = schema["paths"]["/api/descr-union"]["get"]["parameters"][0]["schema"] + get_detail2 = schema["paths"]["/api/regular-union"]["get"]["parameters"][0][ + "schema" + ] # First method should have 'discriminator' in OpenAPI api - assert "discriminator" in detail1 - assert detail1["discriminator"] == { + assert "discriminator" in post_detail1 + assert "discriminator" in get_detail1 + assert post_detail1["discriminator"] == { + "mapping": { + "ONE": "#/components/schemas/Example1", + "TWO": "#/components/schemas/Example2", + }, + "propertyName": "label", + } + assert get_detail1["discriminator"] == { "mapping": { "ONE": "#/components/schemas/Example1", "TWO": "#/components/schemas/Example2", @@ -60,7 +82,8 @@ def test_schema(): } # Second method should NOT have 'discriminator' - assert "discriminator" not in detail2 + assert "discriminator" not in post_detail2 + assert "discriminator" not in get_detail2 def test_annotated_union_with_discriminator(): @@ -108,3 +131,34 @@ def test_regular_union(): "data": {"label": "TWO", "value": 123}, "type": "Example2", } + + +def test_annotated_union_with_discriminator_get(): + # Test Example1 + response = client.get( + "/descr-union", + query_params={"label": "ONE", "value": "42"}, + ) + assert response.status_code == 200 + + # Test Example2 + response = client.get( + "/descr-union", + query_params={"label": "TWO", "value": "42"}, + ) + assert response.status_code == 200 + + +def test_regular_union_get(): + # Test that regular unions still work + response = client.get( + "/regular-union", + query_params={"label": "ONE", "value": "2025"}, + ) + assert response.status_code == 200 + + response = client.get( + "/regular-union", + query_params={"label": "TWO", "value": 123}, + ) + assert response.status_code == 200