Skip to content

Commit

Permalink
allow JWS header types both before GNAPv19 and after
Browse files Browse the repository at this point in the history
  • Loading branch information
alessandrodi committed Jun 5, 2024
1 parent 6f2e97d commit b82079d
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 6 deletions.
3 changes: 2 additions & 1 deletion src/auth_server/models/gnap.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
SupportedAlgorithms,
SupportedHTTPMethods,
SupportedJWSType,
SupportedJWSTypeLegacy,
SymmetricJWK,
)

Expand Down Expand Up @@ -273,7 +274,7 @@ class GrantResponse(GnapBaseModel):
class GNAPJOSEHeader(JOSEHeader):
kid: str
alg: SupportedAlgorithms
typ: SupportedJWSType
typ: Union[SupportedJWSType, SupportedJWSTypeLegacy]
htm: SupportedHTTPMethods
# The HTTP URI used for this request, including all path and query components.
uri: str
Expand Down
7 changes: 6 additions & 1 deletion src/auth_server/models/jose.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,16 @@ class JWKS(BaseModel):
keys: List[Union[ECJWK, RSAJWK, SymmetricJWK]]


class SupportedJWSType(str, Enum):
class SupportedJWSTypeLegacy(str, Enum):
JWS = "gnap-binding+jws"
JWSD = "gnap-binding+jwsd"


class SupportedJWSType(str, Enum):
JWS = "gnap-binding-jws"
JWSD = "gnap-binding-jwsd"


class JOSEHeader(BaseModel):
kid: Optional[str] = None
alg: SupportedAlgorithms
Expand Down
6 changes: 3 additions & 3 deletions src/auth_server/proof/jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from auth_server.config import load_config
from auth_server.context import ContextRequest
from auth_server.models.gnap import GNAPJOSEHeader, Key
from auth_server.models.jose import JWK, SupportedAlgorithms, SupportedJWSType
from auth_server.models.jose import JWK, SupportedAlgorithms, SupportedJWSType, SupportedJWSTypeLegacy
from auth_server.time_utils import utc_now
from auth_server.utils import hash_with

Expand Down Expand Up @@ -90,7 +90,7 @@ async def check_jws_proof(
logger.error("Missing JWS header")
raise HTTPException(status_code=400, detail=f"Missing JWS header: {e}")

if jws_header.typ is not SupportedJWSType.JWS:
if jws_header.typ not in [SupportedJWSType.JWS, SupportedJWSTypeLegacy.JWS]:
raise HTTPException(status_code=400, detail=f"typ should be {SupportedJWSType.JWS}")

return await verify_gnap_jws(request=request, gnap_key=gnap_key, jws_header=jws_header, access_token=access_token)
Expand Down Expand Up @@ -144,7 +144,7 @@ async def check_jwsd_proof(
logger.error(f"Missing Detached JWS header: {e}")
raise HTTPException(status_code=400, detail=str(e))

if jws_header.typ is not SupportedJWSType.JWSD:
if jws_header.typ not in [SupportedJWSType.JWSD, SupportedJWSTypeLegacy.JWSD]:
raise HTTPException(status_code=400, detail=f"typ should be {SupportedJWSType.JWSD}")

return await verify_gnap_jws(request=request, gnap_key=gnap_key, jws_header=jws_header, access_token=access_token)
Expand Down
84 changes: 83 additions & 1 deletion src/auth_server/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
ProofMethod,
StartInteractionMethod,
)
from auth_server.models.jose import ECJWK, SupportedAlgorithms, SupportedHTTPMethods, SupportedJWSType
from auth_server.models.jose import ECJWK, SupportedAlgorithms, SupportedHTTPMethods, SupportedJWSType, SupportedJWSTypeLegacy
from auth_server.models.status import Status
from auth_server.saml2 import AuthnInfo, NameID, SAMLAttributes, SessionInfo
from auth_server.testing import MongoTemporaryInstance
Expand Down Expand Up @@ -321,6 +321,41 @@ def test_transaction_jws(self):
claims = self._get_access_token_claims(access_token=access_token, client=self.client)
assert claims["auth_source"] == AuthSource.TEST

def test_transaction_jws_legacy_typ(self):
client_key_dict = self.client_jwk.export_public(as_dict=True)
client_jwk = ECJWK(**client_key_dict)
req = GrantRequest(
client=Client(key=Key(proof=Proof(method=ProofMethod.JWS), jwk=client_jwk)),
access_token=[AccessTokenRequest(flags=[AccessTokenFlags.BEARER])],
)
jws_header = {
"typ": SupportedJWSTypeLegacy.JWS,
"alg": SupportedAlgorithms.ES256.value,
"kid": self.client_jwk.key_id,
"htm": SupportedHTTPMethods.POST.value,
"uri": "http://testserver/transaction",
"created": int(utc_now().timestamp()),
}
_jws = jws.JWS(payload=req.json(exclude_unset=True))
_jws.add_signature(
key=self.client_jwk,
protected=json.dumps(jws_header),
)
data = _jws.serialize(compact=True)

client_header = {"Content-Type": "application/jose"}
response = self.client.post("/transaction", content=data, headers=client_header)

assert response.status_code == 200
assert "access_token" in response.json()
access_token = response.json()["access_token"]
assert AccessTokenFlags.BEARER.value in access_token["flags"]
assert access_token["value"] is not None
# Verify token and check claims
claims = self._get_access_token_claims(access_token=access_token, client=self.client)
assert claims["auth_source"] == AuthSource.TEST


def test_deserialize_bad_jws(self):
client_header = {"Content-Type": "application/jose"}
response = self.client.post("/transaction", content=b"bogus_jws", headers=client_header)
Expand Down Expand Up @@ -374,6 +409,53 @@ def test_transaction_jwsd(self):
claims = self._get_access_token_claims(access_token=access_token, client=self.client)
assert claims["auth_source"] == AuthSource.TEST

def test_transaction_jwsd_legacy_typ(self):
client_key_dict = self.client_jwk.export_public(as_dict=True)
client_jwk = ECJWK(**client_key_dict)
req = GrantRequest(
client=Client(key=Key(proof=Proof(method=ProofMethod.JWSD), jwk=client_jwk)),
access_token=[AccessTokenRequest(flags=[AccessTokenFlags.BEARER])],
)
jws_header = {
"typ": SupportedJWSTypeLegacy.JWSD,
"alg": SupportedAlgorithms.ES256.value,
"kid": self.client_jwk.key_id,
"htm": SupportedHTTPMethods.POST.value,
"uri": "http://testserver/transaction",
"created": int(utc_now().timestamp()),
}

payload = req.model_dump_json(exclude_unset=True)

# create a hash of payload to send in payload place
payload_digest = hash_with(SHA256(), payload.encode())
payload_hash = base64url_encode(payload_digest)

# create detached jws
_jws = jws.JWS(payload=payload)
_jws.add_signature(
key=self.client_jwk,
protected=json.dumps(jws_header),
)
data = _jws.serialize(compact=True)

# Remove payload from serialized jws
header, _, signature = data.split(".")
client_header = {"Detached-JWS": f"{header}.{payload_hash}.{signature}"}

response = self.client.post(
"/transaction", content=req.model_dump_json(exclude_unset=True), headers=client_header
)

assert response.status_code == 200
assert "access_token" in response.json()
access_token = response.json()["access_token"]
assert AccessTokenFlags.BEARER.value in access_token["flags"]
assert access_token["value"] is not None
# Verify token and check claims
claims = self._get_access_token_claims(access_token=access_token, client=self.client)
assert claims["auth_source"] == AuthSource.TEST

@mock.patch("aiohttp.ClientSession.get", new_callable=AsyncMock)
def test_mdq_flow(self, mock_mdq):
self.config["auth_flows"] = json.dumps(["TestFlow", "MDQFlow"])
Expand Down

0 comments on commit b82079d

Please sign in to comment.