Skip to content

Commit 19c5720

Browse files
authored
Merge pull request #155 from roubaeli/master
Include user claims in refresh tokens
2 parents 02c9fc6 + 664f8b4 commit 19c5720

File tree

6 files changed

+73
-5
lines changed

6 files changed

+73
-5
lines changed

docs/options.rst

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ General Options:
3939
Defaults to ``'identity'`` for legacy reasons.
4040
``JWT_USER_CLAIMS`` Claim in the tokens that is used to store user claims.
4141
Defaults to ``'user_claims'``.
42+
``JWT_CLAIMS_IN_REFRESH_TOKEN`` If user claims should be included in refresh tokens.
43+
Defaults to ``False``.
4244
================================= =========================================
4345

4446

flask_jwt_extended/config.py

+4
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,10 @@ def identity_claim_key(self):
247247
def user_claims_key(self):
248248
return current_app.config['JWT_USER_CLAIMS']
249249

250+
@property
251+
def user_claims_in_refresh_token(self):
252+
return current_app.config['JWT_CLAIMS_IN_REFRESH_TOKEN']
253+
250254
@property
251255
def exempt_methods(self):
252256
return {"OPTIONS"}

flask_jwt_extended/jwt_manager.py

+9
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,8 @@ def _set_default_configuration_options(app):
187187
app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity')
188188
app.config.setdefault('JWT_USER_CLAIMS', 'user_claims')
189189

190+
app.config.setdefault('JWT_CLAIMS_IN_REFRESH_TOKEN', False)
191+
190192
def user_claims_loader(self, callback):
191193
"""
192194
This decorator sets the callback function for adding custom claims to an
@@ -375,13 +377,20 @@ def _create_refresh_token(self, identity, expires_delta=None):
375377
if expires_delta is None:
376378
expires_delta = config.refresh_expires
377379

380+
if config.user_claims_in_refresh_token:
381+
user_claims = self._user_claims_callback(identity)
382+
else:
383+
user_claims = None
384+
378385
refresh_token = encode_refresh_token(
379386
identity=self._user_identity_callback(identity),
380387
secret=config.encode_key,
381388
algorithm=config.algorithm,
382389
expires_delta=expires_delta,
390+
user_claims=user_claims,
383391
csrf=config.csrf_protect,
384392
identity_claim_key=config.identity_claim_key,
393+
user_claims_key=config.user_claims_key,
385394
json_encoder=config.json_encoder
386395
)
387396
return refresh_token

flask_jwt_extended/tokens.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,9 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
7777
json_encoder=json_encoder)
7878

7979

80-
def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf,
81-
identity_claim_key, json_encoder=None):
80+
def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims,
81+
csrf, identity_claim_key, user_claims_key,
82+
json_encoder=None):
8283
"""
8384
Creates a new encoded (utf-8) refresh token.
8485
@@ -88,15 +89,23 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf,
8889
:param expires_delta: How far in the future this token should expire
8990
(set to False to disable expiration)
9091
:type expires_delta: datetime.timedelta or False
92+
:param user_claims: Custom claims to include in this token. This data must
93+
be json serializable
9194
:param csrf: Whether to include a csrf double submit claim in this token
9295
(boolean)
9396
:param identity_claim_key: Which key should be used to store the identity
97+
:param user_claims_key: Which key should be used to store the user claims
9498
:return: Encoded refresh token
9599
"""
96100
token_data = {
97101
identity_claim_key: identity,
98102
'type': 'refresh',
99103
}
104+
105+
# Don't add extra data to the token if user_claims is empty.
106+
if user_claims:
107+
token_data[user_claims_key] = user_claims
108+
100109
if csrf:
101110
token_data['csrf'] = _create_csrf_token()
102111
return _encode_jwt(token_data, expires_delta, secret, algorithm,
@@ -129,8 +138,8 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
129138
if data['type'] == 'access':
130139
if 'fresh' not in data:
131140
raise JWTDecodeError("Missing claim: fresh")
132-
if user_claims_key not in data:
133-
data[user_claims_key] = {}
141+
if user_claims_key not in data:
142+
data[user_claims_key] = {}
134143
if csrf_value:
135144
if 'csrf' not in data:
136145
raise JWTDecodeError("Missing claim: csrf")

tests/test_config.py

+6
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def test_default_configs(app):
6161
assert config.identity_claim_key == 'identity'
6262
assert config.user_claims_key == 'user_claims'
6363

64+
assert config.user_claims_in_refresh_token is False
65+
6466
assert config.json_encoder is app.json_encoder
6567

6668

@@ -100,6 +102,8 @@ def test_override_configs(app):
100102
app.config['JWT_IDENTITY_CLAIM'] = 'foo'
101103
app.config['JWT_USER_CLAIMS'] = 'bar'
102104

105+
app.config['JWT_CLAIMS_IN_REFRESH_TOKEN'] = True
106+
103107
class CustomJSONEncoder(JSONEncoder):
104108
pass
105109

@@ -148,6 +152,8 @@ class CustomJSONEncoder(JSONEncoder):
148152
assert config.identity_claim_key == 'foo'
149153
assert config.user_claims_key == 'bar'
150154

155+
assert config.user_claims_in_refresh_token is True
156+
151157
assert config.json_encoder is CustomJSONEncoder
152158

153159

tests/test_user_claims_loader.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from flask_jwt_extended import (
55
JWTManager, create_access_token, jwt_required, get_jwt_claims,
6-
decode_token
6+
decode_token, jwt_refresh_token_required, create_refresh_token
77
)
88
from tests.utils import get_jwt_manager, make_headers
99

@@ -19,6 +19,11 @@ def app():
1919
def get_claims():
2020
return jsonify(get_jwt_claims())
2121

22+
@app.route('/protected2', methods=['GET'])
23+
@jwt_refresh_token_required
24+
def get_refresh_claims():
25+
return jsonify(get_jwt_claims())
26+
2227
return app
2328

2429

@@ -99,3 +104,36 @@ def add_claims(identity):
99104
response = test_client.get('/protected', headers=make_headers(access_token))
100105
assert response.get_json() == {'foo': 'bar'}
101106
assert response.status_code == 200
107+
108+
109+
def test_user_claim_not_in_refresh_token(app):
110+
jwt = get_jwt_manager(app)
111+
112+
@jwt.user_claims_loader
113+
def add_claims(identity):
114+
return {'foo': 'bar'}
115+
116+
with app.test_request_context():
117+
refresh_token = create_refresh_token('username')
118+
119+
test_client = app.test_client()
120+
response = test_client.get('/protected2', headers=make_headers(refresh_token))
121+
assert response.get_json() == {}
122+
assert response.status_code == 200
123+
124+
125+
def test_user_claim_in_refresh_token(app):
126+
app.config['JWT_CLAIMS_IN_REFRESH_TOKEN'] = True
127+
jwt = get_jwt_manager(app)
128+
129+
@jwt.user_claims_loader
130+
def add_claims(identity):
131+
return {'foo': 'bar'}
132+
133+
with app.test_request_context():
134+
refresh_token = create_refresh_token('username')
135+
136+
test_client = app.test_client()
137+
response = test_client.get('/protected2', headers=make_headers(refresh_token))
138+
assert response.get_json() == {'foo': 'bar'}
139+
assert response.status_code == 200

0 commit comments

Comments
 (0)