From cb0b1a38e5e878202230d27b0ff3545e19f0cb27 Mon Sep 17 00:00:00 2001 From: Sean Hammond Date: Wed, 12 Jun 2024 12:41:54 +0100 Subject: [PATCH] Avoid a joserfc deprecation error You're not supposed to use bytes directly as the key. You're not supposed to do this: >>> from joserfc import jwt >>> >>> encoded = jwt.encode({"alg": "HS256"}, {"k": "value"}, b"secret_key") >>> token = jwt.decode(encoded, b"secret_key") As of joserfc 0.11.0 this raises a `DeprecationWarning`: `"Please use a Key object instead of bytes or string."` https://github.com/authlib/joserfc/blob/bf3eaf55a689f2170191ed3562d4815c4b0ebdd0/src/joserfc/jwk.py#L82-L86 Instead, you're supposed to do this: >>> from joserfc import jwt >>> from joserfc.jwk import OctKey >>> >>> key = OctKey.import_key(b"secret_key") >>> >>> encoded = jwt.encode({"alg": "HS256"}, {"k": "value"}, key) >>> token = jwt.decode(encoded, key) I also had to refactor `encryption_test.py` to fix a subtle issue with its `patch()` fixtures and pytest fixture execution order: the `encryption()` fixture (which instantiates the `h_vialib.secure.encryption.Encryption` object that is to be tested) was being executed **before** the various `patch()` fixtures to patch `joserfc` and `json`. This means that if `Encryption.__init__()` uses `joserfc` then it'll use the real `joserfc` and then, after `__init__()` has executed, the `patch()` fixtures execute and `joserfc` is replaced with a mock after it has already been used. Fix this by turning the patch fixtures into `autouse=True`. This fixes the issue because `autouse=True` fixtures are always executed before other fixtures. `patch()` fixtures should always be `autouse=True`. The `TestEncryption` class actually has some tests that are supposed to use a patched `joserfc` and some tests that are supposed to use the real `joserfc`. This isn't possible with an `autouse=True` patch fixture because the patch fixture will now be used for every test in the class. This forces me to split the patched- and not-patched tests into two separate classes. This is a good thing: it's confusing to have patched and not-patched tests in the same class, easy for the reader to miss that some tests are patched and some aren't. It's better to separate these two distinct types of test into two clearly labelled classes. --- src/h_vialib/secure/encryption.py | 7 ++- src/h_vialib/secure/token.py | 7 ++- tests/unit/h_vialib/secure/encryption_test.py | 63 ++++++++++++------- tests/unit/h_vialib/secure/token_test.py | 5 +- 4 files changed, 52 insertions(+), 30 deletions(-) diff --git a/src/h_vialib/secure/encryption.py b/src/h_vialib/secure/encryption.py index fcaaa90..c80efb4 100644 --- a/src/h_vialib/secure/encryption.py +++ b/src/h_vialib/secure/encryption.py @@ -1,6 +1,7 @@ import json from joserfc import jwe +from joserfc.jwk import OctKey class Encryption: @@ -8,16 +9,16 @@ class Encryption: JWE_ENCRYPTION = "A128CBC-HS256" def __init__(self, secret: bytes): - self._secret = secret.ljust(32)[:32] + self._key = OctKey.import_key(secret.ljust(32)[:32]) def encrypt_dict(self, payload: dict) -> str: """Encrypt a dictionary as a JWE.""" protected = {"alg": self.JWE_ALGORITHM, "enc": self.JWE_ENCRYPTION} return jwe.encrypt_compact( - protected, json.dumps(payload).encode("utf-8"), self._secret + protected, json.dumps(payload).encode("utf-8"), self._key ) def decrypt_dict(self, payload: str) -> dict: """Decrypts payloads created by `encrypt_dict`.""" - data = jwe.decrypt_compact(payload, self._secret).plaintext + data = jwe.decrypt_compact(payload, self._key).plaintext return json.loads(data) diff --git a/src/h_vialib/secure/token.py b/src/h_vialib/secure/token.py index 94f5678..22287f2 100644 --- a/src/h_vialib/secure/token.py +++ b/src/h_vialib/secure/token.py @@ -2,6 +2,7 @@ from joserfc import jwt from joserfc.errors import JoseError +from joserfc.jwk import OctKey from h_vialib.exceptions import InvalidToken, MissingToken from h_vialib.secure.expiry import as_expires @@ -17,7 +18,7 @@ def __init__(self, secret): :param secret: The secret to sign and check tokens with """ - self._secret = secret + self._key = OctKey.import_key(secret) def create(self, payload=None, expires=None, max_age=None) -> str: """Create a secure token. @@ -30,7 +31,7 @@ def create(self, payload=None, expires=None, max_age=None) -> str: :raise ValueError: if neither expires nor max_age is specified """ payload["exp"] = int(as_expires(expires, max_age).timestamp()) - return jwt.encode({"alg": self.TOKEN_ALGORITHM}, payload, self._secret) + return jwt.encode({"alg": self.TOKEN_ALGORITHM}, payload, self._key) def verify(self, token: str) -> dict: """Decode a token and check for validity. @@ -45,7 +46,7 @@ def verify(self, token: str) -> dict: raise MissingToken("Missing secure token") try: - claims = jwt.decode(token, self._secret).claims + claims = jwt.decode(token, self._key).claims jwt.JWTClaimsRegistry().validate(claims) except JoseError as err: raise InvalidToken() from err diff --git a/tests/unit/h_vialib/secure/encryption_test.py b/tests/unit/h_vialib/secure/encryption_test.py index c02a41b..482b657 100644 --- a/tests/unit/h_vialib/secure/encryption_test.py +++ b/tests/unit/h_vialib/secure/encryption_test.py @@ -4,54 +4,71 @@ class TestEncryption: - def test_encrypt_dict_round_trip(self, encryption): + """Tests for h_vialib.secure.encryption that *do not* patch joserfc.""" + + def test_encrypt_dict_decrypt_dict_round_trip(self, encryption): payload_dict = {"some": "data"} encrypted = encryption.encrypt_dict(payload_dict) assert encryption.decrypt_dict(encrypted) == payload_dict - def test_encrypt_dict(self, encryption, secret, json, jwe): + def test_decrypt_dict_hardcoded(self, encryption): + # Copied from the output of decrypt_dict. + # Useful to check backwards compatibility when updating the crypto backend + encrypted = "eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0..q7UXaHtenyFA5VD3QhrxXA.gkAmUrzmW5UFpuF_tZLmcUzUfS9FuLAiV_xqRJBVJ3Y.U42rUD65NVjH-SoFfeDoOw" + + plain_text_dict = encryption.decrypt_dict(encrypted) + + assert plain_text_dict == {"some": "data"} + + +class TestEncryptionPatched: + """Tests for h_vialib.secure.encryption that patch joserfc.""" + + def test_encrypt_dict(self, encryption, secret, OctKey, json, jwe): payload_dict = {"some": "data"} encrypted = encryption.encrypt_dict(payload_dict) + OctKey.import_key.assert_called_once_with(secret.ljust(32)) json.dumps.assert_called_with(payload_dict) jwe.encrypt_compact.assert_called_once_with( {"alg": encryption.JWE_ALGORITHM, "enc": encryption.JWE_ENCRYPTION}, json.dumps.return_value.encode.return_value, - secret.ljust(32), + OctKey.import_key.return_value, ) assert encrypted == jwe.encrypt_compact.return_value - def test_decrypt_dict(self, encryption, secret, json, jwe): + def test_decrypt_dict(self, encryption, secret, json, jwe, OctKey): plain_text_dict = encryption.decrypt_dict("payload") - jwe.decrypt_compact.assert_called_once_with("payload", secret.ljust(32)) + OctKey.import_key.assert_called_once_with(secret.ljust(32)) + jwe.decrypt_compact.assert_called_once_with( + "payload", OctKey.import_key.return_value + ) json.loads.assert_called_once_with(jwe.decrypt_compact.return_value.plaintext) assert plain_text_dict == json.loads.return_value - def test_decrypt_dict_hardcoded(self, encryption): - # Copied from the output of decrypt_dict. - # Useful to check backwards compatibility when updating the crypto backend - encrypted = "eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0..q7UXaHtenyFA5VD3QhrxXA.gkAmUrzmW5UFpuF_tZLmcUzUfS9FuLAiV_xqRJBVJ3Y.U42rUD65NVjH-SoFfeDoOw" + @pytest.fixture(autouse=True) + def json(self, patch): + return patch("h_vialib.secure.encryption.json") - plain_text_dict = encryption.decrypt_dict(encrypted) + @pytest.fixture(autouse=True) + def jwe(self, patch): + return patch("h_vialib.secure.encryption.jwe") - assert plain_text_dict == {"some": "data"} + @pytest.fixture(autouse=True) + def OctKey(self, patch): + return patch("h_vialib.secure.encryption.OctKey") - @pytest.fixture - def secret(self): - return b"VERY SECRET" - @pytest.fixture - def encryption(self, secret): - return Encryption(secret) +@pytest.fixture +def secret(): + return b"VERY SECRET" - @pytest.fixture - def json(self, patch): - return patch("h_vialib.secure.encryption.json") - @pytest.fixture - def jwe(self, patch): - return patch("h_vialib.secure.encryption.jwe") +@pytest.fixture +def encryption(secret): + """Return the h_vialib.encryption.Encryption object to be tested.""" + return Encryption(secret) diff --git a/tests/unit/h_vialib/secure/token_test.py b/tests/unit/h_vialib/secure/token_test.py index e4c24af..59eceed 100644 --- a/tests/unit/h_vialib/secure/token_test.py +++ b/tests/unit/h_vialib/secure/token_test.py @@ -5,13 +5,16 @@ from h_matchers import Any from joserfc import jwt from joserfc.errors import JoseError +from joserfc.jwk import OctKey from h_vialib.exceptions import InvalidToken, MissingToken from h_vialib.secure.token import SecureToken +key = OctKey.import_key("a_very_secret_secret") + def decode_token(token_string): - return jwt.decode(token_string, "a_very_secret_secret").claims + return jwt.decode(token_string, key).claims # This were generated by python-jose but are different than the ones generated by joserfc