Skip to content

Commit 4faeb61

Browse files
author
Paul Asjes
authored
Adds session helpers (#384)
* WIP * WIP * lol camel case * Fix log out url part * Add tests * Mock JWKS call to speed up tests * linting * Make tests and mypy happy * make black happy too * Forgot import * Satisfy type checker * 3.8 compatibility and remove print statements * Use sequence instead of list * Make is_valid_jwt private
1 parent 78eb99e commit 4faeb61

12 files changed

+691
-21
lines changed

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2021 WorkOS
3+
Copyright (c) 2024 WorkOS
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

requirements-dev.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
flake8
2+
pytest==8.3.2
3+
pytest-asyncio==0.23.8
4+
pytest-cov==5.0.0
5+
six==1.16.0
6+
black==24.4.2
7+
twine==5.1.1
8+
mypy==1.12.0
9+
httpx>=0.27.0

requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
httpx>=0.27.0
2+
pydantic==2.9.2
3+
PyJWT==2.9.0
4+
cryptography==43.0.3

setup.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010
with open(os.path.join(base_dir, "workos", "__about__.py")) as f:
1111
exec(f.read(), about)
1212

13+
14+
def read_requirements(filename):
15+
with open(filename) as f:
16+
return [line.strip() for line in f if line.strip() and not line.startswith("#")]
17+
18+
1319
setup(
1420
name=about["__package_name__"],
1521
version=about["__version__"],
@@ -27,19 +33,9 @@
2733
),
2834
zip_safe=False,
2935
license=about["__license__"],
30-
install_requires=["httpx>=0.27.0", "pydantic==2.9.2"],
36+
install_requires=read_requirements("requirements.txt"),
3137
extras_require={
32-
"dev": [
33-
"flake8",
34-
"pytest==8.3.2",
35-
"pytest-asyncio==0.23.8",
36-
"pytest-cov==5.0.0",
37-
"six==1.16.0",
38-
"black==24.4.2",
39-
"twine==5.1.1",
40-
"mypy==1.12.0",
41-
"httpx>=0.27.0",
42-
],
38+
"dev": read_requirements("requirements-dev.txt"),
4339
":python_version<'3.4'": ["enum34"],
4440
},
4541
classifiers=[

tests/conftest.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
from workos.utils.http_client import AsyncHTTPClient, HTTPClient, SyncHTTPClient
2525
from workos.utils.request_helper import DEFAULT_LIST_RESPONSE_LIMIT
2626

27+
from jwt import PyJWKClient
28+
from unittest.mock import Mock, patch
29+
from functools import wraps
30+
2731

2832
def _get_test_client_setup(
2933
http_client_class_name: str,
@@ -302,3 +306,19 @@ def inner(
302306
assert request_kwargs["params"][param] == params[param]
303307

304308
return inner
309+
310+
311+
def with_jwks_mock(func):
312+
@wraps(func)
313+
def wrapper(*args, **kwargs):
314+
# Create mock JWKS client
315+
mock_jwks = Mock(spec=PyJWKClient)
316+
mock_signing_key = Mock()
317+
mock_signing_key.key = kwargs["TEST_CONSTANTS"]["PUBLIC_KEY"]
318+
mock_jwks.get_signing_key_from_jwt.return_value = mock_signing_key
319+
320+
# Apply the mock
321+
with patch("workos.session.PyJWKClient", return_value=mock_jwks):
322+
return func(*args, **kwargs)
323+
324+
return wrapper

tests/test_session.py

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
import pytest
2+
from unittest.mock import Mock, patch
3+
import jwt
4+
from jwt import PyJWKClient
5+
from datetime import datetime, timezone
6+
7+
from tests.conftest import with_jwks_mock
8+
from workos.session import Session
9+
from workos.types.user_management.authentication_response import (
10+
RefreshTokenAuthenticationResponse,
11+
)
12+
from workos.types.user_management.session import (
13+
AuthenticateWithSessionCookieFailureReason,
14+
AuthenticateWithSessionCookieSuccessResponse,
15+
RefreshWithSessionCookieErrorResponse,
16+
RefreshWithSessionCookieSuccessResponse,
17+
)
18+
from workos.types.user_management.user import User
19+
20+
from cryptography.hazmat.primitives import serialization
21+
from cryptography.hazmat.primitives.asymmetric import rsa
22+
23+
24+
@pytest.fixture(scope="session")
25+
def TEST_CONSTANTS():
26+
# Generate RSA key pair for testing
27+
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
28+
29+
public_key = private_key.public_key()
30+
31+
# Get the private key in PEM format
32+
private_pem = private_key.private_bytes(
33+
encoding=serialization.Encoding.PEM,
34+
format=serialization.PrivateFormat.PKCS8,
35+
encryption_algorithm=serialization.NoEncryption(),
36+
)
37+
38+
return {
39+
"COOKIE_PASSWORD": "pfSqwTFXUTGEBBD1RQh2kt/oNJYxBgaoZan4Z8sMrKU=",
40+
"SESSION_DATA": "session_data",
41+
"CLIENT_ID": "client_123",
42+
"USER_ID": "user_123",
43+
"SESSION_ID": "session_123",
44+
"ORGANIZATION_ID": "organization_123",
45+
"CURRENT_TIMESTAMP": str(datetime.now(timezone.utc)),
46+
"PRIVATE_KEY": private_pem,
47+
"PUBLIC_KEY": public_key,
48+
"TEST_TOKEN": jwt.encode(
49+
{
50+
"sid": "session_123",
51+
"org_id": "organization_123",
52+
"role": "admin",
53+
"permissions": ["read"],
54+
"entitlements": ["feature_1"],
55+
"exp": int(datetime.now(timezone.utc).timestamp()) + 3600,
56+
"iat": int(datetime.now(timezone.utc).timestamp()),
57+
},
58+
private_pem,
59+
algorithm="RS256",
60+
),
61+
}
62+
63+
64+
@pytest.fixture
65+
def mock_user_management():
66+
mock = Mock()
67+
mock.get_jwks_url.return_value = (
68+
"https://api.workos.com/user_management/sso/jwks/client_123"
69+
)
70+
71+
return mock
72+
73+
74+
@with_jwks_mock
75+
def test_initialize_session_module(TEST_CONSTANTS, mock_user_management):
76+
session = Session(
77+
user_management=mock_user_management,
78+
client_id=TEST_CONSTANTS["CLIENT_ID"],
79+
session_data=TEST_CONSTANTS["SESSION_DATA"],
80+
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
81+
)
82+
83+
assert session.client_id == TEST_CONSTANTS["CLIENT_ID"]
84+
assert session.cookie_password is not None
85+
86+
87+
@with_jwks_mock
88+
def test_initialize_without_cookie_password(TEST_CONSTANTS, mock_user_management):
89+
with pytest.raises(ValueError, match="cookie_password is required"):
90+
Session(
91+
user_management=mock_user_management,
92+
client_id=TEST_CONSTANTS["CLIENT_ID"],
93+
session_data=TEST_CONSTANTS["SESSION_DATA"],
94+
cookie_password="",
95+
)
96+
97+
98+
@with_jwks_mock
99+
def test_authenticate_no_session_cookie_provided(TEST_CONSTANTS, mock_user_management):
100+
session = Session(
101+
user_management=mock_user_management,
102+
client_id=TEST_CONSTANTS["CLIENT_ID"],
103+
session_data=None,
104+
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
105+
)
106+
107+
response = session.authenticate()
108+
109+
assert (
110+
response.reason
111+
== AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED
112+
)
113+
114+
115+
@with_jwks_mock
116+
def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_management):
117+
session = Session(
118+
user_management=mock_user_management,
119+
client_id=TEST_CONSTANTS["CLIENT_ID"],
120+
session_data="invalid_session_data",
121+
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
122+
)
123+
124+
response = session.authenticate()
125+
126+
assert (
127+
response.reason
128+
== AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
129+
)
130+
131+
132+
@with_jwks_mock
133+
def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management):
134+
invalid_session_data = Session.seal_data(
135+
{"access_token": "invalid_session_data"}, TEST_CONSTANTS["COOKIE_PASSWORD"]
136+
)
137+
session = Session(
138+
user_management=mock_user_management,
139+
client_id=TEST_CONSTANTS["CLIENT_ID"],
140+
session_data=invalid_session_data,
141+
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
142+
)
143+
144+
response = session.authenticate()
145+
146+
assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT
147+
148+
149+
@with_jwks_mock
150+
def test_authenticate_success(TEST_CONSTANTS, mock_user_management):
151+
session = Session(
152+
user_management=mock_user_management,
153+
client_id=TEST_CONSTANTS["CLIENT_ID"],
154+
session_data=TEST_CONSTANTS["SESSION_DATA"],
155+
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
156+
)
157+
158+
# Mock the session data that would be unsealed
159+
mock_session = {
160+
"access_token": jwt.encode(
161+
{
162+
"sid": TEST_CONSTANTS["SESSION_ID"],
163+
"org_id": TEST_CONSTANTS["ORGANIZATION_ID"],
164+
"role": "admin",
165+
"permissions": ["read"],
166+
"entitlements": ["feature_1"],
167+
"exp": int(datetime.now(timezone.utc).timestamp()) + 3600,
168+
"iat": int(datetime.now(timezone.utc).timestamp()),
169+
},
170+
TEST_CONSTANTS["PRIVATE_KEY"],
171+
algorithm="RS256",
172+
),
173+
"user": {
174+
"object": "user",
175+
"id": TEST_CONSTANTS["USER_ID"],
176+
"email": "[email protected]",
177+
"email_verified": True,
178+
"created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
179+
"updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
180+
},
181+
"impersonator": None,
182+
}
183+
184+
# Mock the JWT payload that would be decoded
185+
mock_jwt_payload = {
186+
"sid": TEST_CONSTANTS["SESSION_ID"],
187+
"org_id": TEST_CONSTANTS["ORGANIZATION_ID"],
188+
"role": "admin",
189+
"permissions": ["read"],
190+
"entitlements": ["feature_1"],
191+
}
192+
193+
with patch.object(Session, "unseal_data", return_value=mock_session), patch.object(
194+
session, "_is_valid_jwt", return_value=True
195+
), patch("jwt.decode", return_value=mock_jwt_payload), patch.object(
196+
session.jwks,
197+
"get_signing_key_from_jwt",
198+
return_value=Mock(key=TEST_CONSTANTS["PUBLIC_KEY"]),
199+
):
200+
response = session.authenticate()
201+
202+
assert isinstance(response, AuthenticateWithSessionCookieSuccessResponse)
203+
assert response.authenticated is True
204+
assert response.session_id == TEST_CONSTANTS["SESSION_ID"]
205+
assert response.organization_id == TEST_CONSTANTS["ORGANIZATION_ID"]
206+
assert response.role == "admin"
207+
assert response.permissions == ["read"]
208+
assert response.entitlements == ["feature_1"]
209+
assert response.user.id == TEST_CONSTANTS["USER_ID"]
210+
assert response.impersonator is None
211+
212+
213+
@with_jwks_mock
214+
def test_refresh_invalid_session_cookie(TEST_CONSTANTS, mock_user_management):
215+
session = Session(
216+
user_management=mock_user_management,
217+
client_id=TEST_CONSTANTS["CLIENT_ID"],
218+
session_data="invalid_session_data",
219+
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
220+
)
221+
222+
response = session.refresh()
223+
224+
assert isinstance(response, RefreshWithSessionCookieErrorResponse)
225+
assert (
226+
response.reason
227+
== AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
228+
)
229+
230+
231+
@with_jwks_mock
232+
def test_refresh_success(TEST_CONSTANTS, mock_user_management):
233+
test_user = {
234+
"object": "user",
235+
"id": TEST_CONSTANTS["USER_ID"],
236+
"email": "[email protected]",
237+
"first_name": "Test",
238+
"last_name": "User",
239+
"email_verified": True,
240+
"created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
241+
"updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
242+
}
243+
244+
session_data = Session.seal_data(
245+
{"refresh_token": "refresh_token_12345", "user": test_user},
246+
TEST_CONSTANTS["COOKIE_PASSWORD"],
247+
)
248+
249+
mock_response = {
250+
"access_token": TEST_CONSTANTS["TEST_TOKEN"],
251+
"refresh_token": "refresh_token_123",
252+
"sealed_session": session_data,
253+
"user": test_user,
254+
}
255+
256+
mock_user_management.authenticate_with_refresh_token.return_value = (
257+
RefreshTokenAuthenticationResponse(**mock_response)
258+
)
259+
260+
session = Session(
261+
user_management=mock_user_management,
262+
client_id=TEST_CONSTANTS["CLIENT_ID"],
263+
session_data=session_data,
264+
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
265+
)
266+
267+
with patch.object(session, "_is_valid_jwt", return_value=True) as _:
268+
with patch(
269+
"jwt.decode",
270+
return_value={
271+
"sid": TEST_CONSTANTS["SESSION_ID"],
272+
"org_id": TEST_CONSTANTS["ORGANIZATION_ID"],
273+
"role": "admin",
274+
"permissions": ["read"],
275+
"entitlements": ["feature_1"],
276+
},
277+
):
278+
response = session.refresh()
279+
280+
assert isinstance(response, RefreshWithSessionCookieSuccessResponse)
281+
assert response.authenticated is True
282+
assert response.user.id == test_user["id"]
283+
284+
# Verify the refresh token was used correctly
285+
mock_user_management.authenticate_with_refresh_token.assert_called_once_with(
286+
refresh_token="refresh_token_12345",
287+
organization_id=None,
288+
session={
289+
"seal_session": True,
290+
"cookie_password": TEST_CONSTANTS["COOKIE_PASSWORD"],
291+
},
292+
)
293+
294+
295+
def test_seal_data(TEST_CONSTANTS):
296+
test_data = {"test": "data"}
297+
sealed = Session.seal_data(test_data, TEST_CONSTANTS["COOKIE_PASSWORD"])
298+
assert isinstance(sealed, str)
299+
300+
# Test unsealing
301+
unsealed = Session.unseal_data(sealed, TEST_CONSTANTS["COOKIE_PASSWORD"])
302+
303+
assert unsealed == test_data
304+
305+
306+
def test_unseal_invalid_data(TEST_CONSTANTS):
307+
with pytest.raises(Exception): # Adjust exception type based on your implementation
308+
Session.unseal_data("invalid_sealed_data", TEST_CONSTANTS["COOKIE_PASSWORD"])

0 commit comments

Comments
 (0)