diff --git a/src/h_vialib/secure/encryption.py b/src/h_vialib/secure/encryption.py index fcaaa90..c80efb4 100644 --- a/src/h_vialib/secure/encryption.py +++ b/src/h_vialib/secure/encryption.py @@ -1,6 +1,7 @@ import json from joserfc import jwe +from joserfc.jwk import OctKey class Encryption: @@ -8,16 +9,16 @@ class Encryption: JWE_ENCRYPTION = "A128CBC-HS256" def __init__(self, secret: bytes): - self._secret = secret.ljust(32)[:32] + self._key = OctKey.import_key(secret.ljust(32)[:32]) def encrypt_dict(self, payload: dict) -> str: """Encrypt a dictionary as a JWE.""" protected = {"alg": self.JWE_ALGORITHM, "enc": self.JWE_ENCRYPTION} return jwe.encrypt_compact( - protected, json.dumps(payload).encode("utf-8"), self._secret + protected, json.dumps(payload).encode("utf-8"), self._key ) def decrypt_dict(self, payload: str) -> dict: """Decrypts payloads created by `encrypt_dict`.""" - data = jwe.decrypt_compact(payload, self._secret).plaintext + data = jwe.decrypt_compact(payload, self._key).plaintext return json.loads(data) diff --git a/src/h_vialib/secure/token.py b/src/h_vialib/secure/token.py index 94f5678..22287f2 100644 --- a/src/h_vialib/secure/token.py +++ b/src/h_vialib/secure/token.py @@ -2,6 +2,7 @@ from joserfc import jwt from joserfc.errors import JoseError +from joserfc.jwk import OctKey from h_vialib.exceptions import InvalidToken, MissingToken from h_vialib.secure.expiry import as_expires @@ -17,7 +18,7 @@ def __init__(self, secret): :param secret: The secret to sign and check tokens with """ - self._secret = secret + self._key = OctKey.import_key(secret) def create(self, payload=None, expires=None, max_age=None) -> str: """Create a secure token. @@ -30,7 +31,7 @@ def create(self, payload=None, expires=None, max_age=None) -> str: :raise ValueError: if neither expires nor max_age is specified """ payload["exp"] = int(as_expires(expires, max_age).timestamp()) - return jwt.encode({"alg": self.TOKEN_ALGORITHM}, payload, self._secret) + return jwt.encode({"alg": self.TOKEN_ALGORITHM}, payload, self._key) def verify(self, token: str) -> dict: """Decode a token and check for validity. @@ -45,7 +46,7 @@ def verify(self, token: str) -> dict: raise MissingToken("Missing secure token") try: - claims = jwt.decode(token, self._secret).claims + claims = jwt.decode(token, self._key).claims jwt.JWTClaimsRegistry().validate(claims) except JoseError as err: raise InvalidToken() from err diff --git a/tests/unit/h_vialib/secure/encryption_test.py b/tests/unit/h_vialib/secure/encryption_test.py index c02a41b..482b657 100644 --- a/tests/unit/h_vialib/secure/encryption_test.py +++ b/tests/unit/h_vialib/secure/encryption_test.py @@ -4,54 +4,71 @@ class TestEncryption: - def test_encrypt_dict_round_trip(self, encryption): + """Tests for h_vialib.secure.encryption that *do not* patch joserfc.""" + + def test_encrypt_dict_decrypt_dict_round_trip(self, encryption): payload_dict = {"some": "data"} encrypted = encryption.encrypt_dict(payload_dict) assert encryption.decrypt_dict(encrypted) == payload_dict - def test_encrypt_dict(self, encryption, secret, json, jwe): + def test_decrypt_dict_hardcoded(self, encryption): + # Copied from the output of decrypt_dict. + # Useful to check backwards compatibility when updating the crypto backend + encrypted = "eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0..q7UXaHtenyFA5VD3QhrxXA.gkAmUrzmW5UFpuF_tZLmcUzUfS9FuLAiV_xqRJBVJ3Y.U42rUD65NVjH-SoFfeDoOw" + + plain_text_dict = encryption.decrypt_dict(encrypted) + + assert plain_text_dict == {"some": "data"} + + +class TestEncryptionPatched: + """Tests for h_vialib.secure.encryption that patch joserfc.""" + + def test_encrypt_dict(self, encryption, secret, OctKey, json, jwe): payload_dict = {"some": "data"} encrypted = encryption.encrypt_dict(payload_dict) + OctKey.import_key.assert_called_once_with(secret.ljust(32)) json.dumps.assert_called_with(payload_dict) jwe.encrypt_compact.assert_called_once_with( {"alg": encryption.JWE_ALGORITHM, "enc": encryption.JWE_ENCRYPTION}, json.dumps.return_value.encode.return_value, - secret.ljust(32), + OctKey.import_key.return_value, ) assert encrypted == jwe.encrypt_compact.return_value - def test_decrypt_dict(self, encryption, secret, json, jwe): + def test_decrypt_dict(self, encryption, secret, json, jwe, OctKey): plain_text_dict = encryption.decrypt_dict("payload") - jwe.decrypt_compact.assert_called_once_with("payload", secret.ljust(32)) + OctKey.import_key.assert_called_once_with(secret.ljust(32)) + jwe.decrypt_compact.assert_called_once_with( + "payload", OctKey.import_key.return_value + ) json.loads.assert_called_once_with(jwe.decrypt_compact.return_value.plaintext) assert plain_text_dict == json.loads.return_value - def test_decrypt_dict_hardcoded(self, encryption): - # Copied from the output of decrypt_dict. - # Useful to check backwards compatibility when updating the crypto backend - encrypted = "eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0..q7UXaHtenyFA5VD3QhrxXA.gkAmUrzmW5UFpuF_tZLmcUzUfS9FuLAiV_xqRJBVJ3Y.U42rUD65NVjH-SoFfeDoOw" + @pytest.fixture(autouse=True) + def json(self, patch): + return patch("h_vialib.secure.encryption.json") - plain_text_dict = encryption.decrypt_dict(encrypted) + @pytest.fixture(autouse=True) + def jwe(self, patch): + return patch("h_vialib.secure.encryption.jwe") - assert plain_text_dict == {"some": "data"} + @pytest.fixture(autouse=True) + def OctKey(self, patch): + return patch("h_vialib.secure.encryption.OctKey") - @pytest.fixture - def secret(self): - return b"VERY SECRET" - @pytest.fixture - def encryption(self, secret): - return Encryption(secret) +@pytest.fixture +def secret(): + return b"VERY SECRET" - @pytest.fixture - def json(self, patch): - return patch("h_vialib.secure.encryption.json") - @pytest.fixture - def jwe(self, patch): - return patch("h_vialib.secure.encryption.jwe") +@pytest.fixture +def encryption(secret): + """Return the h_vialib.encryption.Encryption object to be tested.""" + return Encryption(secret) diff --git a/tests/unit/h_vialib/secure/token_test.py b/tests/unit/h_vialib/secure/token_test.py index e4c24af..59eceed 100644 --- a/tests/unit/h_vialib/secure/token_test.py +++ b/tests/unit/h_vialib/secure/token_test.py @@ -5,13 +5,16 @@ from h_matchers import Any from joserfc import jwt from joserfc.errors import JoseError +from joserfc.jwk import OctKey from h_vialib.exceptions import InvalidToken, MissingToken from h_vialib.secure.token import SecureToken +key = OctKey.import_key("a_very_secret_secret") + def decode_token(token_string): - return jwt.decode(token_string, "a_very_secret_secret").claims + return jwt.decode(token_string, key).claims # This were generated by python-jose but are different than the ones generated by joserfc