diff --git a/src/claims.rs b/src/claims.rs index a21fbc90..d77b4829 100644 --- a/src/claims.rs +++ b/src/claims.rs @@ -2,7 +2,9 @@ use std::collections::BTreeMap; -use serde::{Deserialize, Serialize}; +use serde::de::{value, Error, SeqAccess, Visitor}; +use serde::ser::SerializeSeq; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; /// Generic [JWT claims](https://tools.ietf.org/html/rfc7519#page-8) with /// defined fields for registered and private claims. @@ -36,7 +38,7 @@ pub struct RegisteredClaims { pub subject: Option, #[serde(rename = "aud", skip_serializing_if = "Option::is_none")] - pub audience: Option, + pub audience: Option, #[serde(rename = "exp", skip_serializing_if = "Option::is_none")] pub expiration: Option, @@ -51,6 +53,76 @@ pub struct RegisteredClaims { pub json_web_token_id: Option, } +/// Struct to handle the `aud` field because the JWT spec says that +/// it can be either a string or an array of strings. +/// [Audience Claim Specificatgion](https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3). +#[derive(Clone, Debug, Default, PartialEq)] +pub struct StringOrVec { + one: Option, + multi: Option>, +} + +struct StringOrVecVisitor; + +impl<'de> Visitor<'de> for StringOrVecVisitor { + type Value = StringOrVec; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a string or an array of strings") + } + + fn visit_str(self, value: &str) -> Result + where + E: Error, + { + Ok(StringOrVec { + one: Some(value.to_string()), + multi: None, + }) + } + + fn visit_seq(self, seq: S) -> Result + where + S: SeqAccess<'de>, + { + match Deserialize::deserialize(value::SeqAccessDeserializer::new(seq)) { + Ok(r) => Ok(StringOrVec { + one: None, + multi: Some(r), + }), + Err(e) => Err(e), + } + } +} + +impl<'de> Deserialize<'de> for StringOrVec { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_any(StringOrVecVisitor) + } +} + +impl Serialize for StringOrVec { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + if let Some(o) = &self.one { + serializer.serialize_str(&o) + } else if let Some(multi) = &self.multi { + let mut seq = serializer.serialize_seq(Some(multi.len()))?; + for e in multi { + seq.serialize_element(&e)?; + } + seq.end() + } else { + serializer.serialize_none() + } + } +} + #[cfg(test)] mod tests { use crate::claims::Claims; @@ -89,4 +161,39 @@ mod tests { assert_eq!(claims, Claims::from_base64(&*enc)?); Ok(()) } + + #[test] + fn aud_single() -> Result<(), Error> { + // {"iss": "mikkyang.com", "exp": 1302319100, "custom_claim": true, "aud": "test", "alg": "HS256" } + let payload = "eyJpc3MiOiJtaWtreWFuZy5jb20iLCJleHAiOjEzMDIzMTkxMDAsImN1c3RvbV9jbGFpbSI6dHJ1ZSwiYXVkIjoidGVzdCIsImFsZyI6IkhTMjU2In0"; + + let claims = Claims::from_base64(payload)?; + + assert_ne!(claims.registered.audience, None); + + let aud = &claims.registered.audience.unwrap(); + + assert_eq!(aud.one, Some("test".to_string())); + assert_eq!(aud.multi, None); + + Ok(()) + } + + #[test] + fn aud_multi() -> Result<(), Error> { + // {"iss": "mikkyang.com", "exp": 1302319100, "custom_claim": true, "aud": ["test1", "test2"], "alg": "HS256" } + let payload = "eyJpc3MiOiJtaWtreWFuZy5jb20iLCJleHAiOjEzMDIzMTkxMDAsImN1c3RvbV9jbGFpbSI6dHJ1ZSwiYXVkIjpbInRlc3QxIiwidGVzdDIiXSwiYWxnIjoiSFMyNTYifQ"; + + let claims = Claims::from_base64(payload)?; + + assert_ne!(claims.registered.audience, None); + + let aud = &claims.registered.audience.unwrap(); + + assert_eq!(aud.one, None); + assert_eq!(aud.multi.as_ref().unwrap().len(), 2); + assert_eq!(aud.multi.as_ref().unwrap()[0], "test1".to_string()); + assert_eq!(aud.multi.as_ref().unwrap()[1], "test2".to_string()); + Ok(()) + } }