diff --git a/.env.dist b/.env.dist index f7de860..b575690 100644 --- a/.env.dist +++ b/.env.dist @@ -2,6 +2,7 @@ STATIC_ISSUER="inuits-policy-based-auth" STATIC_PRIVATE_KEY="" STATIC_PUBLIC_KEY="" TEST_API_CONFIGURATION=src/tests/integration/test_api/configuration.json +TEST_API_TOKEN_SCHEMA=src/tests/integration/test_api/token_schema.json TEST_API_SCOPES=src/tests/integration/test_api/scopes.json TEST_API_LOGS=src/tests/integration/test_api/logs.txt diff --git a/README.md b/README.md index 76d285f..6185eec 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,8 @@ def __get_class(app, auth_type, policy_module_name): def __instantiate_authentication_policy(policy_module_name, policy, logger: Logger): + token_schema = __load_token_schema() + if policy_module_name == "token_based_policies.authlib_flask_oauth2_policy": allow_anonymous_users = ( True @@ -134,14 +136,26 @@ def __instantiate_authentication_policy(policy_module_name, policy, logger: Logg ) return policy( logger, + token_schema, os.getenv("STATIC_ISSUER"), os.getenv("STATIC_PUBLIC_KEY"), + None, allow_anonymous_users, ) if policy_module_name == "token_based_policies.default_tenant_policy": - return policy(os.getenv("ROLE_SCOPE_MAPPING", os.getenv("API_SCOPES"))) + return policy( + token_schema, os.getenv("ROLE_SCOPE_MAPPING", os.getenv("API_SCOPES")) + ) return policy() + + +def __load_token_schema() -> dict: + token_schema_path = os.getenv( + "TOKEN_SCHEMA", "path/to/token_schema.json" + ) + with open(token_schema_path, "r") as token_schema: + return json.load(token_schema) ``` Now you can import the loader in app.py and pass ```policy_factory``` as an argument to it. diff --git a/src/inuits_policy_based_auth/authentication/policies/token_based_policies/authlib_flask_oauth2_policy.py b/src/inuits_policy_based_auth/authentication/policies/token_based_policies/authlib_flask_oauth2_policy.py index 446043f..a8f4a7a 100644 --- a/src/inuits_policy_based_auth/authentication/policies/token_based_policies/authlib_flask_oauth2_policy.py +++ b/src/inuits_policy_based_auth/authentication/policies/token_based_policies/authlib_flask_oauth2_policy.py @@ -24,6 +24,8 @@ class AuthlibFlaskOauth2Policy(BaseAuthenticationPolicy): ----------- logger : Logger Logger object for logging authentication events and errors. + token_schema : dict + Dict containing mappings between property <-> path.to.that.property.in.token. static_issuer : str, optional A string representing the issuer of the JWT. This parameter is required if remote token validation is not enabled. @@ -33,6 +35,8 @@ class AuthlibFlaskOauth2Policy(BaseAuthenticationPolicy): allowed_issuers : List[str], optional A list of token issuers whose tokens are allowed. If this parameter is not passed or the list is empty, all issuers are allowed. + allow_anonymous_users : bool, optional + A bool about whether anonymous users are allowed to do requests. **kwargs : dict Any additional keyword arguments to be passed to the JWTValidator constructor. """ @@ -40,10 +44,11 @@ class AuthlibFlaskOauth2Policy(BaseAuthenticationPolicy): def __init__( self, logger: Logger, + token_schema: dict, static_issuer=None, static_public_key=None, - allow_anonymous_users=False, allowed_issuers=None, + allow_anonymous_users=False, **kwargs, ): validator = JWTValidator( @@ -58,6 +63,7 @@ def __init__( self._resource_protector = resource_protector self._logger = logger + self._token_schema = token_schema self._allow_anonymous_users = allow_anonymous_users def authenticate(self, user_context, _): @@ -87,7 +93,9 @@ def authenticate(self, user_context, _): user_context.auth_objects.add_key_value_pair("token", token) flattened_token = user_context.flatten_auth_object(token) - user_context.email = flattened_token.get("email", "").lower() + user_context.email = flattened_token.get( + self._token_schema["email"], "" + ).lower() return user_context except InvalidTokenError as error: raise Unauthorized(str(error)) diff --git a/src/inuits_policy_based_auth/authentication/policies/token_based_policies/default_tenant_policy.py b/src/inuits_policy_based_auth/authentication/policies/token_based_policies/default_tenant_policy.py index 462a875..a954b44 100644 --- a/src/inuits_policy_based_auth/authentication/policies/token_based_policies/default_tenant_policy.py +++ b/src/inuits_policy_based_auth/authentication/policies/token_based_policies/default_tenant_policy.py @@ -15,11 +15,14 @@ class DefaultTenantPolicy(BaseAuthenticationPolicy): Parameters: ----------- + token_schema : dict + Dict containing mappings between property <-> path.to.that.property.in.token. role_scope_mapping_filepath : str, optional Path to a JSON file containing a mapping of scopes to their corresponding roles. """ - def __init__(self, role_scope_mapping_filepath=None): + def __init__(self, token_schema: dict, role_scope_mapping_filepath=None): + self._token_schema = token_schema self._role_scope_mapping = self.__load_role_scope_mapping( role_scope_mapping_filepath ) @@ -54,7 +57,7 @@ def authenticate(self, user_context: UserContext, _): user_context.x_tenant.id = "/" user_context.x_tenant.roles = flattened_token.get( - f"resource_access.{token['azp']}.roles", [] + self._token_schema["roles"], [] ) if self._role_scope_mapping: for role in user_context.x_tenant.roles: diff --git a/src/tests/integration/test_api/policy_loader.py b/src/tests/integration/test_api/policy_loader.py index 74fa418..5d43f5b 100644 --- a/src/tests/integration/test_api/policy_loader.py +++ b/src/tests/integration/test_api/policy_loader.py @@ -57,6 +57,8 @@ def __get_class(app, auth_type, policy_module_name): def __instantiate_authentication_policy(policy_module_name, policy, logger: Logger): + token_schema = __load_token_schema() + if policy_module_name == "token_based_policies.authlib_flask_oauth2_policy": allow_anonymous_users = ( True @@ -65,11 +67,23 @@ def __instantiate_authentication_policy(policy_module_name, policy, logger: Logg ) return policy( logger, + token_schema, os.getenv("STATIC_ISSUER"), os.getenv("STATIC_PUBLIC_KEY"), + None, allow_anonymous_users, ) if policy_module_name == "token_based_policies.default_tenant_policy": - return policy(os.getenv("ROLE_SCOPE_MAPPING", os.getenv("TEST_API_SCOPES"))) + return policy( + token_schema, os.getenv("ROLE_SCOPE_MAPPING", os.getenv("TEST_API_SCOPES")) + ) return policy() + + +def __load_token_schema() -> dict: + token_schema_path = os.getenv( + "TEST_API_TOKEN_SCHEMA", "src/tests/integration/test_api/token_schema.json" + ) + with open(token_schema_path, "r") as token_schema: + return json.load(token_schema) diff --git a/src/tests/integration/test_api/token_schema.json b/src/tests/integration/test_api/token_schema.json new file mode 100644 index 0000000..c0233d6 --- /dev/null +++ b/src/tests/integration/test_api/token_schema.json @@ -0,0 +1,4 @@ +{ + "email": "email", + "roles": "resource_access.inuits-policy-based-auth.roles" +}