diff --git a/ninja/signature/details.py b/ninja/signature/details.py index 8f026b067..4391ab774 100644 --- a/ninja/signature/details.py +++ b/ninja/signature/details.py @@ -183,9 +183,28 @@ def _create_models(self) -> TModels: return result def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]]: - flatten_map = {} + flatten_map: Dict[str, Tuple[str, ...]] = {} arg_names: Any = {} for arg in args: + # Check if this is an optional union type with None default + if get_origin(arg.annotation) in UNION_TYPES: + union_args = get_args(arg.annotation) + has_none = type(None) in union_args + # If it's a union with None and the source default is None (like Query(None)), don't flatten it + if ( + has_none + and hasattr(arg.source, "default") + and arg.source.default is None + ): + name = arg.alias + if name in flatten_map: + raise ConfigError( + f"Duplicated name: '{name}' also in '{arg_names[name]}'" + ) + flatten_map[name] = (name,) + arg_names[name] = name + continue + if is_pydantic_model(arg.annotation): for name, path in self._model_flatten_map(arg.annotation, arg.alias): if name in flatten_map: @@ -207,12 +226,19 @@ def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]] def _model_flatten_map(self, model: TModel, prefix: str) -> Generator: field: FieldInfo - for attr, field in model.model_fields.items(): - field_name = field.alias or attr - name = f"{prefix}{self.FLATTEN_PATH_SEP}{field_name}" - if is_pydantic_model(field.annotation): - yield from self._model_flatten_map(field.annotation, name) # type: ignore - else: + if get_origin(model) in UNION_TYPES: + # If the model is a union type, process each type in the union + for arg in get_args(model): + yield from self._model_flatten_map(arg, prefix) + else: + for attr, field in model.model_fields.items(): + field_name = field.alias or attr + name = f"{prefix}{self.FLATTEN_PATH_SEP}{field_name}" + + if get_origin( + field.annotation + ) not in UNION_TYPES and is_pydantic_model(field.annotation): + yield from self._model_flatten_map(field.annotation, name) # type: ignore yield field_name, name def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam: @@ -260,15 +286,21 @@ def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam: # 2) if param name is a part of the path parameter elif name in self.path_params_names: - assert ( - default == self.signature.empty - ), f"'{name}' is a path param, default not allowed" + assert default == self.signature.empty, ( + f"'{name}' is a path param, default not allowed" + ) param_source = Path(...) # 3) if param is a collection, or annotation is part of pydantic model: elif is_collection or is_pydantic_model(annotation): if default == self.signature.empty: - param_source = Body(...) + # Check if this is a Union type that includes None - if so, None should be a valid value + if get_origin(annotation) in UNION_TYPES and type(None) in get_args( + annotation + ): + param_source = Body(None) # Make it optional with None default + else: + param_source = Body(...) else: param_source = Body(default) @@ -295,7 +327,11 @@ def is_pydantic_model(cls: Any) -> bool: # Handle Union types if origin in UNION_TYPES: - return any(issubclass(arg, pydantic.BaseModel) for arg in get_args(cls)) + return any( + issubclass(arg, pydantic.BaseModel) + for arg in get_args(cls) + if arg is not type(None) + ) return issubclass(cls, pydantic.BaseModel) except TypeError: # pragma: no cover return False @@ -338,20 +374,31 @@ def detect_collection_fields( for attr in path[1:]: if hasattr(annotation_or_field, "annotation"): annotation_or_field = annotation_or_field.annotation - annotation_or_field = next( - ( - a - for a in annotation_or_field.model_fields.values() - if a.alias == attr - ), - annotation_or_field.model_fields.get(attr), - ) # pragma: no cover + + # check union types + if get_origin(annotation_or_field) in UNION_TYPES: + for arg in get_args(annotation_or_field): # pragma: no branch + found_field = next( + (a for a in arg.model_fields.values() if a.alias == attr), + arg.model_fields.get(attr), + ) + if found_field is not None: + annotation_or_field = found_field + break + else: + annotation_or_field = next( + ( + a + for a in annotation_or_field.model_fields.values() + if a.alias == attr + ), + annotation_or_field.model_fields.get(attr), + ) annotation_or_field = getattr( annotation_or_field, "outer_type_", annotation_or_field ) - # if hasattr(annotation_or_field, "annotation"): annotation_or_field = annotation_or_field.annotation if is_collection_type(annotation_or_field): diff --git a/ninja/testing/client.py b/ninja/testing/client.py index 50bd01c57..759722fff 100644 --- a/ninja/testing/client.py +++ b/ninja/testing/client.py @@ -206,7 +206,7 @@ def __init__(self, http_response: Union[HttpResponse, StreamingHttpResponse]): if self.streaming: self.content = b"".join(http_response.streaming_content) # type: ignore else: - self.content = http_response.content # type: ignore[union-attr] + self.content = http_response.content self._data = None def json(self) -> Any: diff --git a/tests/test_misc.py b/tests/test_misc.py index 355dc8883..97eb4a0c4 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -15,8 +15,12 @@ def test_is_pydantic_model(): class Model(BaseModel): x: int + class ModelNone(BaseModel): + x: int | None + assert is_pydantic_model(Model) assert is_pydantic_model("instance") is False + assert is_pydantic_model(ModelNone) def test_client(): diff --git a/tests/test_models.py b/tests/test_models.py index 034bfb0fd..c4088c749 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,7 +1,10 @@ +from typing import List, Union + import pytest from pydantic import BaseModel -from ninja import Form, Query, Router +from ninja import Form, NinjaAPI, Query, Router +from ninja.errors import ConfigError from ninja.testing import TestClient @@ -9,6 +12,7 @@ class SomeModel(BaseModel): i: int s: str f: float + n: int | None = None class OtherModel(BaseModel): @@ -24,6 +28,35 @@ class SelfReference(BaseModel): SelfReference.model_rebuild() +# Union type models for testing union type handling +class PetSchema(BaseModel): + nicknames: list[str] + + +class PersonSchema(BaseModel): + name: str + age: int + pet: PetSchema | None = None + + +class NestedUnionModel(BaseModel): + nested_field: Union[PersonSchema, None] = None + simple_field: str = "default" + + +class ComplexUnionModel(BaseModel): + # Union field with pydantic models + model_union: Union[SomeModel, OtherModel] | None = None + # Simple union field + simple_union: Union[str, int] = "default" + + +# Model with non-optional union of pydantic models +class MultiModelUnion(BaseModel): + # Union of multiple pydantic models without None - + models: Union[SomeModel, OtherModel] # No default, no None + + router = Router() @@ -76,6 +109,62 @@ def view7(request, obj: OtherModel = OtherModel(x=1, y=1)): return obj +# Union type test views +@router.post("/test-union-query") +def view_union_query(request, person: PersonSchema = Query(...)): + return person + + +@router.post("/test-union-body") +def view_union_body(request, union_body: Union[SomeModel, OtherModel]): + return union_body + + +@router.post("/test-optional-union") +def view_optional_union(request, optional_model: Union[SomeModel, None] = Query(None)): + if optional_model is None: + return {"result": "none"} + return {"result": optional_model} + + +@router.post("/test-nested-union") +def view_nested_union(request, data: NestedUnionModel): + return data.model_dump() + + +@router.post("/test-complex-union") +def view_complex_union(request, data: ComplexUnionModel = Query(...)): + return data + + +# Test direct union parameter to cover _model_flatten_map +@router.post("/test-direct-union") +def view_direct_union(request, model: Union[SomeModel, OtherModel] = Query(...)): + return model + + +# Test union of pydantic models +@router.post("/test-multi-model-union") +def view_multi_model_union(request, data: MultiModelUnion): + return data.model_dump() + + +@router.post("/test-union-with-none") +def view_union_with_none(request, optional: Union[str, None] = Query(None)): + """Test Union[str, None]""" + return {"optional": optional} + + +class CollectionUnionModel(BaseModel): + items: List[str] + nested: Union[SomeModel, None] = None + + +@router.post("/test-collection-union") +def view_collection_union(request, data: CollectionUnionModel): + return data.model_dump() + + client = TestClient(router) @@ -86,7 +175,12 @@ def view7(request, obj: OtherModel = OtherModel(x=1, y=1)): ( "/test1", dict(json={"i": "1", "s": "foo", "f": "1.1"}), - {"i": 1, "s": "foo", "f": 1.1}, + {"i": 1, "s": "foo", "f": 1.1, "n": None}, + ), + ( + "/test1", + dict(json={"i": "1", "s": "foo", "f": "1.1", "n": 42}), + {"i": 1, "s": "foo", "f": 1.1, "n": 42}, ), ( "/test2", @@ -96,12 +190,15 @@ def view7(request, obj: OtherModel = OtherModel(x=1, y=1)): "other": {"x": 1, "y": 2}, } ), - {"some": {"i": 1, "s": "foo", "f": 1.1}, "other": {"x": 1, "y": 2}}, + { + "some": {"i": 1, "s": "foo", "f": 1.1, "n": None}, + "other": {"x": 1, "y": 2}, + }, ), ( "/test3", dict(json={"i": "1", "s": "foo", "f": "1.1"}), - {"i": 1, "s": "foo", "f": 1.1}, + {"i": 1, "s": "foo", "f": 1.1, "n": None}, ), ( "/test_form", @@ -133,6 +230,62 @@ def view7(request, obj: OtherModel = OtherModel(x=1, y=1)): dict(json=None), {"x": 1, "y": 1}, ), + ( + "/test-union-query?name=John&age=30", + dict(json=None), + {"name": "John", "age": 30, "pet": None}, + ), + ( + "/test-union-body", + dict(json={"i": 1, "s": "test", "f": 1.5}), + {"i": 1, "s": "test", "f": 1.5, "n": None}, + ), + ( + "/test-direct-union?i=1&s=test&f=1.5", + dict(json=None), + {"i": 1, "s": "test", "f": 1.5, "n": None}, + ), + ( + "/test-union-with-none", + dict(json=None), + {"optional": None}, + ), + ( + "/test-union-with-none?optional=test", + dict(json=None), + {"optional": "test"}, + ), + # Test collection union model + ( + "/test-collection-union", + dict(json={"items": ["a", "b"], "nested": None}), + {"items": ["a", "b"], "nested": None}, + ), + ( + "/test-collection-union", + dict(json={"items": ["x"], "nested": {"i": 5, "s": "test", "f": 2.0}}), + {"items": ["x"], "nested": {"i": 5, "s": "test", "f": 2.0, "n": None}}, + ), + ( + "/test-multi-model-union", + dict(json={"models": {"i": 1, "s": "test", "f": 1.5}}), + {"models": {"i": 1, "s": "test", "f": 1.5, "n": None}}, + ), + ( + "/test-optional-union", + dict(json=None), + {"result": "none"}, + ), + ( + "/test-nested-union", + dict(json={"nested_field": None, "simple_field": "test"}), + {"nested_field": None, "simple_field": "test"}, + ), + ( + "/test-complex-union?simple_union=42", + dict(json=None), + {"model_union": None, "simple_union": "42"}, + ), ], # fmt: on ) @@ -148,3 +301,33 @@ def test_invalid_body(): assert response.json() == { "detail": "Cannot parse request body", } + + +def test_union_query_name_collision(): + """Test that duplicate union parameter names with Query(None) raise ConfigError.""" + + with pytest.raises(ConfigError, match=r"Duplicated name.*person"): + api = NinjaAPI() + router_test = Router() + + @router_test.post("/collision-test") + def collision_endpoint( + person1: Union[PersonSchema, None] = Query(None, alias="person"), + person2: Union[PersonSchema, None] = Query(None, alias="person"), + ): + return {"result": "should not reach here"} + + api.add_router("/test", router_test) + + +def test_union_with_none_body_param(): + """Test Union[Model, None] parameter""" + + test_router = Router() + + @test_router.post("/test-union-none-body") + def test_union_none_body(request, data: Union[SomeModel, None]): + return data.model_dump() if data else {"result": "none"} + + # Verify the router was created successfully and has one registered operation + assert len(test_router.path_operations) == 1