From 0073127de5662bf23fbeca041d0833507ce94aae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89loi=20Rivard?= Date: Mon, 2 Dec 2024 10:44:04 +0100 Subject: [PATCH] feat: Implement Resource.get_extension_model --- doc/changelog.rst | 9 ++++++- scim2_models/rfc7643/resource.py | 10 +++++++- tests/test_resource_extension.py | 41 ++++++++++++++++++++++++++++++-- 3 files changed, 56 insertions(+), 4 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index f75aef6..81682c6 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -1,6 +1,13 @@ Changelog ========= +[0.2.9] - Unreleased +-------------------- + +Added +^^^^^ +- Implement :meth:`Resource.get_extension_model `. + [0.2.8] - 2024-12-02 -------------------- @@ -13,7 +20,7 @@ Added Added ^^^^^ -- Implement :meth:`ResourceType.from_resource`. +- Implement :meth:`ResourceType.from_resource `. [0.2.6] - 2024-11-29 -------------------- diff --git a/scim2_models/rfc7643/resource.py b/scim2_models/rfc7643/resource.py index 69af464..8a9d174 100644 --- a/scim2_models/rfc7643/resource.py +++ b/scim2_models/rfc7643/resource.py @@ -177,7 +177,7 @@ def __setitem__(self, item: Any, value: "Resource"): setattr(self, item.__name__, value) @classmethod - def get_extension_models(cls) -> dict[str, type]: + def get_extension_models(cls) -> dict[str, type[Extension]]: """Return extension a dict associating extension models with their schemas.""" extension_models = cls.__pydantic_generic_metadata__.get("args", []) extension_models = ( @@ -191,6 +191,14 @@ def get_extension_models(cls) -> dict[str, type]: } return by_schema + @classmethod + def get_extension_model(cls, name_or_schema) -> Optional[type[Extension]]: + """Return an extension by its name or schema.""" + for schema, extension in cls.get_extension_models().items(): + if schema == name_or_schema or extension.__name__ == name_or_schema: + return extension + return None + @staticmethod def get_by_schema( resource_types: list[type[BaseModel]], schema: str, with_extensions=True diff --git a/tests/test_resource_extension.py b/tests/test_resource_extension.py index 950a9f0..b4a4934 100644 --- a/tests/test_resource_extension.py +++ b/tests/test_resource_extension.py @@ -201,7 +201,7 @@ def test_invalid_setitem(): class SuperHero(Extension): - schemas: list[str] = ["urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"] + schemas: list[str] = ["example:extensions:SuperHero"] superpower: Optional[str] = None """The superhero superpower.""" @@ -217,8 +217,9 @@ def test_multiple_extensions_union(): "schemas": [ "urn:ietf:params:scim:schemas:core:2.0:User", "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User", + "example:extensions:SuperHero", ], - "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User": { + "example:extensions:SuperHero": { "superpower": "flight", }, } @@ -268,3 +269,39 @@ def test_validate_items_without_extension(): User[EnterpriseUser].model_validate( payload, scim_ctx=Context.RESOURCE_CREATION_RESPONSE ) + + +def test_get_extension_model(): + assert User[EnterpriseUser].get_extension_model("EnterpriseUser") == EnterpriseUser + assert ( + User[EnterpriseUser].get_extension_model( + "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User" + ) + == EnterpriseUser + ) + + assert ( + User[Union[EnterpriseUser, SuperHero]].get_extension_model("EnterpriseUser") + == EnterpriseUser + ) + assert ( + User[Union[EnterpriseUser, SuperHero]].get_extension_model( + "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User" + ) + == EnterpriseUser + ) + + assert User[SuperHero].get_extension_model("EnterpriseUser") is None + assert ( + User[SuperHero].get_extension_model( + "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User" + ) + is None + ) + assert User.get_extension_model("EnterpriseUser") is None + assert ( + User.get_extension_model( + "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User" + ) + is None + )