diff --git a/src/lib.rs b/src/lib.rs index 44325ff..a77559d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,6 +54,7 @@ pub(crate) struct SDJWTCommon { hash_to_decoded_disclosure: HashMap, hash_to_disclosure: HashMap, input_disclosures: Vec, + sign_alg: Option, } #[derive(Default, Serialize, Deserialize, Clone, Eq, PartialEq, Debug)] @@ -141,6 +142,7 @@ impl SDJWTCommon { length: parts.len(), msg: format!("Invalid SD-JWT: {}", sd_jwt_with_disclosures), })?; + self.sign_alg = Self::decode_header_and_get_sign_algorithm(&sd_jwt); self.unverified_input_key_binding_jwt = Some( parts .next_back() @@ -179,12 +181,14 @@ impl SDJWTCommon { self.input_disclosures = parsed_sd_jwt_json.disclosures; self.unverified_input_sd_jwt_payload = Some(jwt_payload_decode(&parsed_sd_jwt_json.payload)?); - self.unverified_sd_jwt = Some(format!( + let sd_jwt = format!( "{}.{}.{}", parsed_sd_jwt_json.protected, parsed_sd_jwt_json.payload, parsed_sd_jwt_json.signature - )); + ); + self.unverified_sd_jwt = Some(sd_jwt.clone()); + self.sign_alg = Self::decode_header_and_get_sign_algorithm(&sd_jwt); Ok(()) } @@ -198,6 +202,25 @@ impl SDJWTCommon { } } } + /// Decodes a header jwt string and extracts the "alg" field from the JSON object. + /// # Arguments + /// * `sd_jwt` - jwt format string. + /// # Returns + /// * `Option` - The result containing the algorithm String e.g ES256 or on failure None. + fn decode_header_and_get_sign_algorithm(sd_jwt: &str) -> Option { + let parts: Vec<&str> = sd_jwt.split('.').collect(); + if parts.len() < 2 { + return None; + } + let jwt_header = parts[0]; + let decoded = base64url_decode(jwt_header).ok()?; + let decoded_str = std::str::from_utf8(&decoded).ok()?; + let json_sign_alg: Value = serde_json::from_str(decoded_str).ok()?; + let sign_alg = json_sign_alg.get("alg") + .and_then(Value::as_str) + .map(String::from); + sign_alg + } } #[cfg(test)] diff --git a/src/verifier.rs b/src/verifier.rs index 7b27752..21444d5 100644 --- a/src/verifier.rs +++ b/src/verifier.rs @@ -63,14 +63,15 @@ impl SDJWTVerifier { verifier.sd_jwt_engine.parse_sd_jwt(sd_jwt_presentation)?; verifier.sd_jwt_engine.create_hash_mappings()?; - verifier.verify_sd_jwt(Some(DEFAULT_SIGNING_ALG.to_owned()))?; + let sign_alg = verifier.sd_jwt_engine.sign_alg.clone(); + verifier.verify_sd_jwt(sign_alg.clone())?; verifier.verified_claims = verifier.extract_sd_claims()?; if let (Some(expected_aud), Some(expected_nonce)) = (&expected_aud, &expected_nonce) { verifier.verify_key_binding_jwt( expected_aud.to_owned(), expected_nonce.to_owned(), - Some(DEFAULT_SIGNING_ALG), + sign_alg.as_deref(), )?; } else if expected_aud.is_some() || expected_nonce.is_some() { return Err(Error::InvalidInput( @@ -99,11 +100,15 @@ impl SDJWTVerifier { .as_str() .ok_or(Error::ConversionError("str".to_string()))?; let issuer_public_key = (self.cb_get_issuer_key)(unverified_issuer, &parsed_header_sd_jwt); - + let algorithm: Algorithm = match sign_alg { + Some(alg_str) => Algorithm::from_str(&alg_str) + .map_err(|e| Error::DeserializationError(e.to_string()))?, + None => Algorithm::ES256, // Default or handle as needed + }; let claims = jsonwebtoken::decode( sd_jwt, &issuer_public_key, - &Validation::new(Algorithm::ES256), + &Validation::new(algorithm), ) .map_err(|e| Error::DeserializationError(format!("Cannot decode jwt: {}", e)))? .claims; @@ -376,6 +381,8 @@ mod tests { const PRIVATE_ISSUER_PEM: &str = "-----BEGIN PRIVATE KEY-----\nMIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgUr2bNKuBPOrAaxsR\nnbSH6hIhmNTxSGXshDSUD1a1y7ihRANCAARvbx3gzBkyPDz7TQIbjF+ef1IsxUwz\nX1KWpmlVv+421F7+c1sLqGk4HUuoVeN8iOoAcE547pJhUEJyf5Asc6pP\n-----END PRIVATE KEY-----\n"; const PUBLIC_ISSUER_PEM: &str = "-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEb28d4MwZMjw8+00CG4xfnn9SLMVM\nM19SlqZpVb/uNtRe/nNbC6hpOB1LqFXjfIjqAHBOeO6SYVBCcn+QLHOqTw==\n-----END PUBLIC KEY-----\n"; + const PRIVATE_ISSUER_ED25519_PEM: &str = "-----BEGIN PRIVATE KEY-----\nMFECAQEwBQYDK2VwBCIEIF93k6rxZ8W38cm0rOwfGdH+YY3k10hP+7gd0falPLg0\ngSEAdW31QyWzfed4EPcw1rYuUa1QU+fXEL0HhdAfYZRkihc=\n-----END PRIVATE KEY-----\n"; + const PUBLIC_ISSUER_ED25519_PEM: &str = "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAdW31QyWzfed4EPcw1rYuUa1QU+fXEL0HhdAfYZRkihc=\n-----END PUBLIC KEY-----\n"; #[test] fn verify_full_presentation() { @@ -662,4 +669,56 @@ mod tests { assert_eq!(verified_claims, expected_verified_claims); } + + #[test] + fn verify_full_presentation_to_allow_other_algorithms_json_format() { + + let user_claims = json!({ + "sub": "6c5c0a49-b589-431d-bae7-219122a9ec2c", + "iss": "https://example.com/issuer", + "iat": 1683000000, + "exp": 1883000000, + "address": { + "street_address": "Schulstr. 12", + "locality": "Schulpforta", + "region": "Sachsen-Anhalt", + "country": "DE" + } + }); + let private_issuer_bytes = PRIVATE_ISSUER_ED25519_PEM.as_bytes(); + let issuer_key = EncodingKey::from_ed_pem(private_issuer_bytes).unwrap(); + let sd_jwt = SDJWTIssuer::new(issuer_key, Some("EdDSA".to_string())).issue_sd_jwt( + user_claims.clone(), + ClaimsForSelectiveDisclosureStrategy::AllLevels, + None, + false, + SDJWTSerializationFormat::JSON, // Changed to Json format + ) + .unwrap(); + + let presentation = SDJWTHolder::new(sd_jwt.clone(), SDJWTSerializationFormat::JSON) // Changed to Json format + .unwrap() + .create_presentation( + user_claims.as_object().unwrap().clone(), + None, + None, + None, + None + ) + .unwrap(); + assert_eq!(sd_jwt, presentation); + let verified_claims = SDJWTVerifier::new( + presentation, + Box::new(|_, _| { + let public_issuer_bytes = PUBLIC_ISSUER_ED25519_PEM.as_bytes(); + DecodingKey::from_ed_pem(public_issuer_bytes).unwrap() + }), + None, + None, + SDJWTSerializationFormat::JSON, // Changed to Json format + ) + .unwrap() + .verified_claims; + assert_eq!(user_claims, verified_claims); + } }