Skip to content

Commit

Permalink
Merge pull request #52 from brettchaldecott/feat/add-multiple-session…
Browse files Browse the repository at this point in the history
…-support

feat: added multi environment support
  • Loading branch information
brettchaldecott authored Mar 4, 2025
2 parents 46875a9 + 5c8aa11 commit 49187ca
Show file tree
Hide file tree
Showing 2 changed files with 245 additions and 0 deletions.
159 changes: 159 additions & 0 deletions kinde_sdk/kinde_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ def is_authenticated(self) -> bool:
self._refresh_token()
return True
return False

def is_authenticated_token(self, token_value: dict) -> dict:
if token_value:
if token_value.is_expired():
return self._refresh_token_value(token_value)
return None

def create_org(self) -> str:
return f"{self.registration_url}&is_create_org=true"
Expand All @@ -145,23 +151,50 @@ def get_claim(self, key: str, token_name: str = "access_token") -> Any:
self._decode_token_if_needed(token_name)
value = self.__decoded_tokens[token_name].get(key)
return {"name": key, "value": value}

def get_claim_token(self, token_value: dict, key: str, token_name: str = "access_token") -> Any:
if token_name not in self.TOKEN_NAMES:
raise KindeTokenException(
f"Please use only tokens from the list: {self.TOKEN_NAMES}"
)

decoded_tokens = self._decode_token_if_needed_value(token_name,token_value)
value = decoded_tokens[token_name].get(key)
return {"name": key, "value": value}

def get_permission(self, permission: str) -> Dict[str, Any]:
return {
"org_code": self.get_claim("org_code")["value"],
"is_granted": permission in self.get_claim("permissions")["value"],
}

def get_permission_token(self, token_value: dict, permission: str) -> Dict[str, Any]:
return {
"org_code": self.get_claim_token(token_value, "org_code")["value"],
"is_granted": permission in self.get_claim_token(token_value, "permissions")["value"],
}

def get_permissions(self) -> Dict[str, Any]:
return {
"org_code": self.get_claim("org_code")["value"],
"permissions": self.get_claim("permissions")["value"],
}

def get_permissions_token(self, token_value: dict) -> Dict[str, Any]:
return {
"org_code": self.get_claim_token(token_value, "org_code")["value"],
"permissions": self.get_claim_token(token_value, "permissions")["value"],
}

def get_organization(self) -> Dict[str, str]:
return {
"org_code": self.get_claim("org_code")["value"],
}

def get_organization_token(self, token_value: dict) -> Dict[str, str]:
return {
"org_code": self.get_claim_token(token_value, "org_code")["value"],
}

def get_user_details(self) -> Dict[str, str]:
return {
Expand All @@ -171,11 +204,25 @@ def get_user_details(self) -> Dict[str, str]:
"email": self.get_claim("email", "id_token")["value"],
"picture": self.get_claim("picture", "id_token")["value"],
}

def get_user_details_token(self,token_value: dict) -> Dict[str, str]:
return {
"id": self.get_claim_token(token_value, "sub","id_token")["value"],
"given_name": self.get_claim_token(token_value, "given_name", "id_token")["value"],
"family_name": self.get_claim_token(token_value, "family_name", "id_token")["value"],
"email": self.get_claim_token(token_value, "email", "id_token")["value"],
"picture": self.get_claim_token(token_value, "picture", "id_token")["value"],
}

def get_user_organizations(self) -> Dict[str, List[str]]:
return {
"org_codes": self.get_claim("org_codes", "id_token")["value"],
}

def get_user_organizations_token(self, token_value: dict) -> Dict[str, List[str]]:
return {
"org_codes": self.get_claim_token(token_value, "org_codes", "id_token")["value"],
}

def get_flag(
self, code: str, default_value: Any = None, flag_type: str = ""
Expand Down Expand Up @@ -205,15 +252,53 @@ def get_flag(
result_flag["type"] = FlagType[flag_type].value

return result_flag

def get_flag_token(
self, token_value: dict, code: str, default_value: Any = None, flag_type: str = ""
) -> Any:
flags = self.get_claim_token(token_value, "feature_flags")["value"] or {}
flag = {}

if code not in list(flags.keys()):
if default_value is None:
raise KindeRetrieveException(
f"Flag {code} was not found, and no default value has been provided"
)
else:
flag = flags[code]
if flag_type and flag.get("t") and flag_type != flag.get("t"):
raise KindeRetrieveException(
f"Flag {code} is of type {FlagType[flag.get('t')].value} - requested type {FlagType[flag_type].value}"
)

result_flag = {
"code": code,
"value": flag.get("v") if flag else default_value,
"is_default": not bool(flag),
}
flag_type = flag["t"] if flag else flag_type
if flag_type:
result_flag["type"] = FlagType[flag_type].value

return result_flag

def get_boolean_flag(self, code: str, default_value: Any = None) -> bool:
return self.get_flag(code, default_value, "b")["value"]

def get_boolean_flag_token(self, token_value: dict, code: str, default_value: Any = None) -> bool:
return self.get_flag_token(token_value, code, default_value, "b")["value"]

def get_string_flag(self, code: str, default_value: Any = None) -> str:
return self.get_flag(code, default_value, "s")["value"]

def get_string_flag_token(self, token_value: dict, code: str, default_value: Any = None) -> str:
return self.get_flag_token(token_value, code, default_value, "s")["value"]

def get_integer_flag(self, code: str, default_value: Any = None) -> int:
return self.get_flag(code, default_value, "i")["value"]

def get_integer_flag_token(self, token_value: dict, code: str, default_value: Any = None) -> int:
return self.get_flag_token(token_value, code, default_value, "i")["value"]

def call_api(self, *args, **kwargs) -> Any:
self._get_or_refresh_access_token()
Expand Down Expand Up @@ -248,6 +333,34 @@ def _decode_token_if_needed(self, token_name: str) -> None:
self.__decoded_tokens[token_name] = jwt.decode(**decode_token_params)
else:
raise KindeTokenException(f"Token {token_name} doesn't exist.")

def _decode_token_if_needed_value(self, token_name: str, token_value: dict) -> dict:
if token_name not in token_value:
if not token_value:
raise KindeTokenException(
"Access token doesn't exist.\n"
"When grant_type is CLIENT_CREDENTIALS use fetch_token().\n"
'For other grant_type use "get_login_url()" or "get_register_url()".'
)
token = token_value.get(token_name)

signing_key = self.jwks_client.get_signing_key_from_jwt(token)

if token:
decode_token_params = {
"jwt":token,
"key": signing_key.key,
"algorithms":["RS256"],
"options":{
"verify_signature": True,
"verify_exp": True,
"verify_aud": False
}
}
return jwt.decode(**decode_token_params)
else:
raise KindeTokenException(f"Token {token_name} doesn't exist.")
return token_value

def fetch_token(self, authorization_response: Optional[str] = None) -> None:
if self.grant_type == GrantType.CLIENT_CREDENTIALS:
Expand All @@ -274,6 +387,31 @@ def fetch_token(self, authorization_response: Optional[str] = None) -> None:
self.configuration.access_token = self.__access_token_obj.get("access_token")
self._clear_decoded_tokens()

def fetch_token_value(self, authorization_response: Optional[str] = None) -> dict:
if self.grant_type == GrantType.CLIENT_CREDENTIALS:
params = {"grant_type": "client_credentials"}
if self.audience:
params["audience"] = self.audience
else:
if authorization_response is None:
raise KindeConfigurationException(
'"authorization_response" parameter is required when grant_type is different than CLIENT_CREDENTIALS.'
)
params = {"authorization_response": authorization_response}
if self.grant_type == GrantType.AUTHORIZATION_CODE_WITH_PKCE:
params["code_verifier"] = self.code_verifier

access_token_obj = self.client.fetch_token(
self.token_endpoint,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"Kinde-SDK": "/".join(("Python", kinde_sdk_version)),
},
**params,
)
return access_token_obj


def _get_or_refresh_access_token(self) -> None:
if self.grant_type == GrantType.CLIENT_CREDENTIALS:
if not self.__access_token_obj or self.__access_token_obj.is_expired():
Expand Down Expand Up @@ -309,6 +447,27 @@ def _refresh_token(self) -> None:
self._clear_decoded_tokens()
else:
raise KindeTokenException('"Access token" and "Refresh token" are invalid.')

def _refresh_token_value(self, token_value: dict) -> dict:
refresh_token = token_value.get("refresh_token")

if refresh_token:
token_value = self.client.refresh_token(
self.token_endpoint,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"Kinde-SDK": "/".join(("Python", kinde_sdk_version)),
},
refresh_token=refresh_token,
)
if not token_value:
raise KindeTokenException(
'"Access token" and "Refresh token" are invalid.'
)

return token_value
else:
raise KindeTokenException('"Access token" and "Refresh token" are invalid.')

def _add_additional_params(self, url: str, additional_params: Optional[Dict[str, str]] = None) -> str:

Expand Down
86 changes: 86 additions & 0 deletions kinde_sdk/test/test_kinde_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ def test_fetch_token_authorization_code(self):
client.fetch_token(authorization_response="https://example.com/callback?code=test_code")
self.mock_oauth2_session.return_value.fetch_token.assert_called_once()

def test_fetch_token_authorization_code_token(self):
client = self._create_kinde_client(GrantType.AUTHORIZATION_CODE)
self.mock_oauth2_session.return_value.fetch_token.return_value = {"access_token": "test_token"}
token_value = client.fetch_token_value(authorization_response="https://example.com/callback?code=test_code")
self.mock_oauth2_session.return_value.fetch_token.assert_called_once()

@patch('kinde_sdk.kinde_api_client.ApiClient.call_api')
def test_super_call_api_with_correct_args(self, mock_super_call_api):
client = self._create_kinde_client(GrantType.CLIENT_CREDENTIALS)
Expand Down Expand Up @@ -142,6 +148,86 @@ def mock_get_claim_side_effect(key):
mock_get_claim.assert_any_call("permissions")
self.assertEqual(mock_get_claim.call_count, 2)

def test_get_permissions_token(self):
client = self._create_kinde_client(GrantType.AUTHORIZATION_CODE)

result = client.get_permissions_token({"access_token":{"org_code":"org123","permissions": ["read", "write", "delete"]}})

expected_result = {
"org_code": "org123",
"permissions": ["read", "write", "delete"]
}
self.assertEqual(result, expected_result)

def test_get_permission_token(self):
client = self._create_kinde_client(GrantType.AUTHORIZATION_CODE)

result = client.get_permission_token({"access_token":{"org_code":"org123","permissions": ["read", "write", "delete"]}},"read")

expected_result = {
"org_code": "org123",
"is_granted": True
}
self.assertEqual(result, expected_result)


def test_get_claim_token(self):
client = self._create_kinde_client(GrantType.AUTHORIZATION_CODE)

result = client.get_claim_token({"access_token":{"org_code":"org123","permissions": ["read", "write", "delete"]}},"org_code")

expected_result = {
"name": "org_code",
"value": "org123"
}
self.assertEqual(result, expected_result)

def test_get_user_details_token(self):
client = self._create_kinde_client(GrantType.AUTHORIZATION_CODE)

result = client.get_user_details_token({
"access_token":{"org_code":"org123","permissions": ["read", "write", "delete"]}
,"id_token":{"sub":"123","given_name":"John","family_name":"Doe","email":"[email protected]","picture":"https://example.com/pic.jpg"}
})

expected_result = {
"id":"123","given_name":"John","family_name":"Doe","email":"[email protected]","picture":"https://example.com/pic.jpg"
}
self.assertEqual(result, expected_result)


def test_get_flag_token(self):
client = self._create_kinde_client(GrantType.AUTHORIZATION_CODE)

result = client.get_flag_token({
"access_token":{"org_code":"org123","permissions": ["read", "write", "delete"],"feature_flags":{"test_flag":{"v":True,"t":"b"}}}
,"id_token":{"sub":"123","given_name":"John","family_name":"Doe","email":"[email protected]","picture":"https://example.com/pic.jpg"}
},
"test_flag"
)

expected_result = {
"code":"test_flag",
"value":True,
"is_default":False,
"type":"boolean"
}
self.assertEqual(result, expected_result)

def test_get_boolean_flag_token(self):
client = self._create_kinde_client(GrantType.AUTHORIZATION_CODE)

result = client.get_boolean_flag_token({
"access_token":{"org_code":"org123","permissions": ["read", "write", "delete"],"feature_flags":{"test_flag":{"v":True,"t":"b"}}}
,"id_token":{"sub":"123","given_name":"John","family_name":"Doe","email":"[email protected]","picture":"https://example.com/pic.jpg"}
},
"test_flag"
)

expected_result = True
self.assertEqual(result, expected_result)


def test_fetch_token_headers_with_authorization_code(self):
client = self._create_kinde_client(GrantType.AUTHORIZATION_CODE)
self.mock_oauth2_session.return_value.fetch_token.return_value = {"access_token": "test_token"}
Expand Down

0 comments on commit 49187ca

Please sign in to comment.