Skip to content

Commit

Permalink
Adds session helpers (#384)
Browse files Browse the repository at this point in the history
* WIP

* WIP

* lol camel case

* Fix log out url part

* Add tests

* Mock JWKS call to speed up tests

* linting

* Make tests and mypy happy

* make black happy too

* Forgot import

* Satisfy type checker

* 3.8 compatibility and remove print statements

* Use sequence instead of list

* Make is_valid_jwt private
  • Loading branch information
Paul Asjes authored Dec 2, 2024
1 parent 78eb99e commit 4faeb61
Show file tree
Hide file tree
Showing 12 changed files with 691 additions and 21 deletions.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2021 WorkOS
Copyright (c) 2024 WorkOS

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
9 changes: 9 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
flake8
pytest==8.3.2
pytest-asyncio==0.23.8
pytest-cov==5.0.0
six==1.16.0
black==24.4.2
twine==5.1.1
mypy==1.12.0
httpx>=0.27.0
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
httpx>=0.27.0
pydantic==2.9.2
PyJWT==2.9.0
cryptography==43.0.3
20 changes: 8 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
with open(os.path.join(base_dir, "workos", "__about__.py")) as f:
exec(f.read(), about)


def read_requirements(filename):
with open(filename) as f:
return [line.strip() for line in f if line.strip() and not line.startswith("#")]


setup(
name=about["__package_name__"],
version=about["__version__"],
Expand All @@ -27,19 +33,9 @@
),
zip_safe=False,
license=about["__license__"],
install_requires=["httpx>=0.27.0", "pydantic==2.9.2"],
install_requires=read_requirements("requirements.txt"),
extras_require={
"dev": [
"flake8",
"pytest==8.3.2",
"pytest-asyncio==0.23.8",
"pytest-cov==5.0.0",
"six==1.16.0",
"black==24.4.2",
"twine==5.1.1",
"mypy==1.12.0",
"httpx>=0.27.0",
],
"dev": read_requirements("requirements-dev.txt"),
":python_version<'3.4'": ["enum34"],
},
classifiers=[
Expand Down
20 changes: 20 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
from workos.utils.http_client import AsyncHTTPClient, HTTPClient, SyncHTTPClient
from workos.utils.request_helper import DEFAULT_LIST_RESPONSE_LIMIT

from jwt import PyJWKClient
from unittest.mock import Mock, patch
from functools import wraps


def _get_test_client_setup(
http_client_class_name: str,
Expand Down Expand Up @@ -302,3 +306,19 @@ def inner(
assert request_kwargs["params"][param] == params[param]

return inner


def with_jwks_mock(func):
@wraps(func)
def wrapper(*args, **kwargs):
# Create mock JWKS client
mock_jwks = Mock(spec=PyJWKClient)
mock_signing_key = Mock()
mock_signing_key.key = kwargs["TEST_CONSTANTS"]["PUBLIC_KEY"]
mock_jwks.get_signing_key_from_jwt.return_value = mock_signing_key

# Apply the mock
with patch("workos.session.PyJWKClient", return_value=mock_jwks):
return func(*args, **kwargs)

return wrapper
308 changes: 308 additions & 0 deletions tests/test_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
import pytest
from unittest.mock import Mock, patch
import jwt
from jwt import PyJWKClient
from datetime import datetime, timezone

from tests.conftest import with_jwks_mock
from workos.session import Session
from workos.types.user_management.authentication_response import (
RefreshTokenAuthenticationResponse,
)
from workos.types.user_management.session import (
AuthenticateWithSessionCookieFailureReason,
AuthenticateWithSessionCookieSuccessResponse,
RefreshWithSessionCookieErrorResponse,
RefreshWithSessionCookieSuccessResponse,
)
from workos.types.user_management.user import User

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa


@pytest.fixture(scope="session")
def TEST_CONSTANTS():
# Generate RSA key pair for testing
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

public_key = private_key.public_key()

# Get the private key in PEM format
private_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)

return {
"COOKIE_PASSWORD": "pfSqwTFXUTGEBBD1RQh2kt/oNJYxBgaoZan4Z8sMrKU=",
"SESSION_DATA": "session_data",
"CLIENT_ID": "client_123",
"USER_ID": "user_123",
"SESSION_ID": "session_123",
"ORGANIZATION_ID": "organization_123",
"CURRENT_TIMESTAMP": str(datetime.now(timezone.utc)),
"PRIVATE_KEY": private_pem,
"PUBLIC_KEY": public_key,
"TEST_TOKEN": jwt.encode(
{
"sid": "session_123",
"org_id": "organization_123",
"role": "admin",
"permissions": ["read"],
"entitlements": ["feature_1"],
"exp": int(datetime.now(timezone.utc).timestamp()) + 3600,
"iat": int(datetime.now(timezone.utc).timestamp()),
},
private_pem,
algorithm="RS256",
),
}


@pytest.fixture
def mock_user_management():
mock = Mock()
mock.get_jwks_url.return_value = (
"https://api.workos.com/user_management/sso/jwks/client_123"
)

return mock


@with_jwks_mock
def test_initialize_session_module(TEST_CONSTANTS, mock_user_management):
session = Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=TEST_CONSTANTS["SESSION_DATA"],
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)

assert session.client_id == TEST_CONSTANTS["CLIENT_ID"]
assert session.cookie_password is not None


@with_jwks_mock
def test_initialize_without_cookie_password(TEST_CONSTANTS, mock_user_management):
with pytest.raises(ValueError, match="cookie_password is required"):
Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=TEST_CONSTANTS["SESSION_DATA"],
cookie_password="",
)


@with_jwks_mock
def test_authenticate_no_session_cookie_provided(TEST_CONSTANTS, mock_user_management):
session = Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=None,
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)

response = session.authenticate()

assert (
response.reason
== AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED
)


@with_jwks_mock
def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_management):
session = Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data="invalid_session_data",
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)

response = session.authenticate()

assert (
response.reason
== AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
)


@with_jwks_mock
def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management):
invalid_session_data = Session.seal_data(
{"access_token": "invalid_session_data"}, TEST_CONSTANTS["COOKIE_PASSWORD"]
)
session = Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=invalid_session_data,
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)

response = session.authenticate()

assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT


@with_jwks_mock
def test_authenticate_success(TEST_CONSTANTS, mock_user_management):
session = Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=TEST_CONSTANTS["SESSION_DATA"],
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)

# Mock the session data that would be unsealed
mock_session = {
"access_token": jwt.encode(
{
"sid": TEST_CONSTANTS["SESSION_ID"],
"org_id": TEST_CONSTANTS["ORGANIZATION_ID"],
"role": "admin",
"permissions": ["read"],
"entitlements": ["feature_1"],
"exp": int(datetime.now(timezone.utc).timestamp()) + 3600,
"iat": int(datetime.now(timezone.utc).timestamp()),
},
TEST_CONSTANTS["PRIVATE_KEY"],
algorithm="RS256",
),
"user": {
"object": "user",
"id": TEST_CONSTANTS["USER_ID"],
"email": "[email protected]",
"email_verified": True,
"created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
"updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
},
"impersonator": None,
}

# Mock the JWT payload that would be decoded
mock_jwt_payload = {
"sid": TEST_CONSTANTS["SESSION_ID"],
"org_id": TEST_CONSTANTS["ORGANIZATION_ID"],
"role": "admin",
"permissions": ["read"],
"entitlements": ["feature_1"],
}

with patch.object(Session, "unseal_data", return_value=mock_session), patch.object(
session, "_is_valid_jwt", return_value=True
), patch("jwt.decode", return_value=mock_jwt_payload), patch.object(
session.jwks,
"get_signing_key_from_jwt",
return_value=Mock(key=TEST_CONSTANTS["PUBLIC_KEY"]),
):
response = session.authenticate()

assert isinstance(response, AuthenticateWithSessionCookieSuccessResponse)
assert response.authenticated is True
assert response.session_id == TEST_CONSTANTS["SESSION_ID"]
assert response.organization_id == TEST_CONSTANTS["ORGANIZATION_ID"]
assert response.role == "admin"
assert response.permissions == ["read"]
assert response.entitlements == ["feature_1"]
assert response.user.id == TEST_CONSTANTS["USER_ID"]
assert response.impersonator is None


@with_jwks_mock
def test_refresh_invalid_session_cookie(TEST_CONSTANTS, mock_user_management):
session = Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data="invalid_session_data",
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)

response = session.refresh()

assert isinstance(response, RefreshWithSessionCookieErrorResponse)
assert (
response.reason
== AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
)


@with_jwks_mock
def test_refresh_success(TEST_CONSTANTS, mock_user_management):
test_user = {
"object": "user",
"id": TEST_CONSTANTS["USER_ID"],
"email": "[email protected]",
"first_name": "Test",
"last_name": "User",
"email_verified": True,
"created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
"updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
}

session_data = Session.seal_data(
{"refresh_token": "refresh_token_12345", "user": test_user},
TEST_CONSTANTS["COOKIE_PASSWORD"],
)

mock_response = {
"access_token": TEST_CONSTANTS["TEST_TOKEN"],
"refresh_token": "refresh_token_123",
"sealed_session": session_data,
"user": test_user,
}

mock_user_management.authenticate_with_refresh_token.return_value = (
RefreshTokenAuthenticationResponse(**mock_response)
)

session = Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=session_data,
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)

with patch.object(session, "_is_valid_jwt", return_value=True) as _:
with patch(
"jwt.decode",
return_value={
"sid": TEST_CONSTANTS["SESSION_ID"],
"org_id": TEST_CONSTANTS["ORGANIZATION_ID"],
"role": "admin",
"permissions": ["read"],
"entitlements": ["feature_1"],
},
):
response = session.refresh()

assert isinstance(response, RefreshWithSessionCookieSuccessResponse)
assert response.authenticated is True
assert response.user.id == test_user["id"]

# Verify the refresh token was used correctly
mock_user_management.authenticate_with_refresh_token.assert_called_once_with(
refresh_token="refresh_token_12345",
organization_id=None,
session={
"seal_session": True,
"cookie_password": TEST_CONSTANTS["COOKIE_PASSWORD"],
},
)


def test_seal_data(TEST_CONSTANTS):
test_data = {"test": "data"}
sealed = Session.seal_data(test_data, TEST_CONSTANTS["COOKIE_PASSWORD"])
assert isinstance(sealed, str)

# Test unsealing
unsealed = Session.unseal_data(sealed, TEST_CONSTANTS["COOKIE_PASSWORD"])

assert unsealed == test_data


def test_unseal_invalid_data(TEST_CONSTANTS):
with pytest.raises(Exception): # Adjust exception type based on your implementation
Session.unseal_data("invalid_sealed_data", TEST_CONSTANTS["COOKIE_PASSWORD"])
Loading

0 comments on commit 4faeb61

Please sign in to comment.