Skip to content

Commit

Permalink
API fixes: Rework SD-JWT Claims Strategy and add enum for SD-JWT Seri…
Browse files Browse the repository at this point in the history
…alization Format (#21)

* Added the SDJWTSerializationFormat enum
* Renamed SDJWTClaimsStrategy -> ClaimsForSelectiveDisclosureStrategy
* Removed the 'leak' function call
* Added unit tests for the SDJWTClaimsStrategy::next_level method
* Code cleanup.


Signed-off-by: Alexander Sukhachev <[email protected]>
Signed-off-by: Sergey Minaev <[email protected]>
Co-authored-by: Sergey Minaev <[email protected]>
  • Loading branch information
alexsdsr and jovfer authored Jan 22, 2024
1 parent deddfcf commit ed88931
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 176 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Proposals about API improvements are highly appreciated.
```rust
fn demo() {
let mut issuer = SDJWTIssuer::new(issuer_key, None);
let sd_jwt = issuer.issue_sd_jwt(claims, SDJWTClaimsStrategy::Full, holder_key, add_decoy, "compact".to_owned()).unwrap();
let sd_jwt = issuer.issue_sd_jwt(claims, ClaimsForSelectiveDisclosureStrategy::AllLevels, holder_key, add_decoy, "compact".to_owned()).unwrap();

let mut holder = SDJWTHolder::new(sd_jwt, "compact".to_owned()).unwrap();
let presentation = holder.create_presentation(claims_to_disclosure, None, None, None, None).unwrap();
Expand Down
40 changes: 16 additions & 24 deletions src/holder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{error, SDJWTJson};
use crate::{error, SDJWTJson, SDJWTSerializationFormat};
use error::{Error, Result};
use jsonwebtoken::{Algorithm, EncodingKey, Header};
use serde_json::{Map, Value};
Expand Down Expand Up @@ -29,7 +29,7 @@ impl SDJWTHolder {
///
/// # Arguments
/// * `sd_jwt_with_disclosures` - SD JWT with disclosures in the format specified by `serialization_format`
/// * `serialization_format` - Serialization format of the SD JWT. Supported values are `compact` and `json`
/// * `serialization_format` - Serialization format of the SD JWT, see [SDJWTSerializationFormat].
///
/// # Returns
/// * `SDJWTHolder` - Instance of SDJWTHolder
Expand All @@ -38,18 +38,10 @@ impl SDJWTHolder {
/// * `InvalidInput` - If the serialization format is not supported
/// * `InvalidState` - If the SD JWT data is not valid
/// * `DeserializationError` - If the SD JWT serialization is not valid
pub fn new(sd_jwt_with_disclosures: String, serialization_format: String) -> Result<Self> {
let serialization_format = serialization_format.to_lowercase();
if serialization_format != "compact" && serialization_format != "json" {
return Err(Error::InvalidInput(format!(
"Serialization format \"{}\" is not supported",
serialization_format
)));
}

pub fn new(sd_jwt_with_disclosures: String, serialization_format: SDJWTSerializationFormat) -> Result<Self> {
let mut holder = SDJWTHolder {
sd_jwt_engine: SDJWTCommon {
serialization_format: serialization_format.clone(),
serialization_format,
..Default::default()
},
hs_disclosures: Vec::new(),
Expand Down Expand Up @@ -119,7 +111,7 @@ impl SDJWTHolder {
}
}

let sd_jwt_presentation = if self.sd_jwt_engine.serialization_format == "compact" {
let sd_jwt_presentation = if self.sd_jwt_engine.serialization_format == SDJWTSerializationFormat::Compact {
let mut combined: Vec<&str> = Vec::with_capacity(self.hs_disclosures.len() + 2);
combined.push(&self.serialized_sd_jwt);
combined.extend(self.hs_disclosures.iter().map(|s| s.as_str()));
Expand Down Expand Up @@ -349,8 +341,8 @@ impl SDJWTHolder {

#[cfg(test)]
mod tests {
use crate::issuer::SDJWTClaimsStrategy;
use crate::{SDJWTHolder, SDJWTIssuer, COMBINED_SERIALIZATION_FORMAT_SEPARATOR};
use crate::issuer::ClaimsForSelectiveDisclosureStrategy;
use crate::{SDJWTHolder, SDJWTIssuer, COMBINED_SERIALIZATION_FORMAT_SEPARATOR, SDJWTSerializationFormat};
use jsonwebtoken::EncodingKey;
use serde_json::{json, Map, Value};
use std::collections::HashSet;
Expand All @@ -375,15 +367,15 @@ mod tests {
let issuer_key = EncodingKey::from_ec_pem(private_issuer_bytes).unwrap();
let sd_jwt = SDJWTIssuer::new(issuer_key, None).issue_sd_jwt(
user_claims.clone(),
SDJWTClaimsStrategy::Full,
ClaimsForSelectiveDisclosureStrategy::AllLevels,
None,
false,
"compact".to_owned(),
SDJWTSerializationFormat::Compact,
)
.unwrap();
let presentation = SDJWTHolder::new(
sd_jwt.clone(),
"compact".to_ascii_lowercase(),
SDJWTSerializationFormat::Compact,
)
.unwrap()
.create_presentation(
Expand Down Expand Up @@ -415,16 +407,16 @@ mod tests {

let sd_jwt = SDJWTIssuer::new(issuer_key, None).issue_sd_jwt(
user_claims.clone(),
SDJWTClaimsStrategy::Full,
ClaimsForSelectiveDisclosureStrategy::AllLevels,
None,
false,
"compact".to_owned(),
SDJWTSerializationFormat::Compact,
)
.unwrap();
let issued = sd_jwt.clone();
user_claims["address"] = Value::Object(Map::new());
let presentation =
SDJWTHolder::new(sd_jwt, "compact".to_ascii_lowercase())
SDJWTHolder::new(sd_jwt, SDJWTSerializationFormat::Compact)
.unwrap()
.create_presentation(
user_claims.as_object().unwrap().clone(),
Expand Down Expand Up @@ -472,7 +464,7 @@ mod tests {
]
}
);
let strategy = SDJWTClaimsStrategy::Partial(vec![
let strategy = ClaimsForSelectiveDisclosureStrategy::Custom(vec![
"$.name",
"$.addresses[1]",
"$.addresses[1].country",
Expand All @@ -486,7 +478,7 @@ mod tests {
strategy,
None,
false,
"compact".to_owned(),
SDJWTSerializationFormat::Compact,
)
.unwrap();
// Choose what to reveal
Expand All @@ -496,7 +488,7 @@ mod tests {
let issued = sd_jwt.clone();
println!("{}", issued);
let presentation =
SDJWTHolder::new(sd_jwt, "compact".to_ascii_lowercase())
SDJWTHolder::new(sd_jwt, SDJWTSerializationFormat::Compact)
.unwrap()
.create_presentation(
user_claims.as_object().unwrap().clone(),
Expand Down
128 changes: 80 additions & 48 deletions src/issuer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::utils::{base64_hash, generate_salt};
use crate::{
SDJWTCommon, CNF_KEY, COMBINED_SERIALIZATION_FORMAT_SEPARATOR, DEFAULT_DIGEST_ALG,
DEFAULT_SIGNING_ALG, DIGEST_ALG_KEY, JWK_KEY, SD_DIGESTS_KEY, SD_LIST_PREFIX,
SDJWTSerializationFormat,
};

pub struct SDJWTIssuer {
Expand All @@ -36,28 +37,30 @@ pub struct SDJWTIssuer {
pub serialized_sd_jwt: String,
}

/// SDJWTClaimsStrategy is used to determine which claims can be selectively disclosed later by the holder.
pub enum SDJWTClaimsStrategy<'a> {
/// No claims can be selectively disclosed. Full disclosure.
No,
/// Top-level claims can be selectively disclosed, nested objects are fully disclosed.
Flat,
/// ClaimsForSelectiveDisclosureStrategy is used to determine which claims can be selectively disclosed later by the holder.
#[derive(PartialEq, Debug)]
pub enum ClaimsForSelectiveDisclosureStrategy<'a> {
/// No claims can be selectively disclosed, so all claims are always disclosed in presentations generated by the holder.
NoSDClaims,
/// Top-level claims can be selectively disclosed, nested objects are fully disclosed, if a parent claim is disclosed.
TopLevel,
/// All claims can be selectively disclosed (recursively including nested objects).
Full,
/// Claims can be selectively disclosed based on the provided JSONPaths. Other claims are fully disclosed.
AllLevels,
/// Claims can be selectively disclosed based on the provided JSONPaths.
/// Other claims are always disclosed in presentation generated by the holder.
/// # Examples
/// ```
/// use sd_jwt_rs::issuer::SDJWTClaimsStrategy;
/// use sd_jwt_rs::issuer::ClaimsForSelectiveDisclosureStrategy;
///
/// let strategy = SDJWTClaimsStrategy::Partial(vec!["$.address", "$.address.street_address"]);
/// let strategy = ClaimsForSelectiveDisclosureStrategy::Custom(vec!["$.address", "$.address.street_address"]);
/// ```
Partial(Vec<&'a str>),
Custom(Vec<&'a str>),
}

impl<'a> SDJWTClaimsStrategy<'a> {
impl<'a> ClaimsForSelectiveDisclosureStrategy<'a> {
fn finalize_input(&mut self) -> Result<()> {
match self {
SDJWTClaimsStrategy::Partial(keys) => {
ClaimsForSelectiveDisclosureStrategy::Custom(keys) => {
for key in keys.iter_mut() {
if let Some(new_key) = key.strip_prefix("$.") {
*key = new_key;
Expand All @@ -73,37 +76,33 @@ impl<'a> SDJWTClaimsStrategy<'a> {

fn next_level(&self, key: &str) -> Self {
match self {
Self::No => Self::No,
Self::Flat => Self::No,
Self::Full => Self::Full,
Self::Partial(sd_keys) => {
Self::NoSDClaims => Self::NoSDClaims,
Self::TopLevel => Self::NoSDClaims,
Self::AllLevels => Self::AllLevels,
Self::Custom(sd_keys) => {
let next_sd_keys = sd_keys
.iter()
.filter_map(|str| {
str.strip_prefix(key).as_mut().map(|claim| {
if let Some(next_claim) = claim.strip_prefix('.') {
next_claim
} else {
// FIXME Replace to non-leackable impl
// Removes "[", "]" symbols form "index" and returns "next_claim" as "index.remained_claims.."
// For example: [0].street -> 0.street
*claim = claim.replace(['[', ']'], "").leak();
claim
str.strip_prefix(key).and_then(|str|
match str.chars().next() {
Some('.') => Some(&str[1..]), // next token
Some('[') => Some(str), // array index
_ => None,
}
})
)
})
.collect();
Self::Partial(next_sd_keys)
Self::Custom(next_sd_keys)
}
}
}

fn sd_for_key(&self, key: &str) -> bool {
match self {
Self::No => false,
Self::Flat => true,
Self::Full => true,
Self::Partial(sd_keys) => sd_keys.contains(&key),
Self::NoSDClaims => false,
Self::TopLevel => true,
Self::AllLevels => true,
Self::Custom(sd_keys) => sd_keys.contains(&key),
}
}
}
Expand Down Expand Up @@ -149,20 +148,20 @@ impl SDJWTIssuer {
///
/// # Arguments
/// * `user_claims` - The claims to be included in the SD-JWT.
/// * `sd_strategy` - The strategy to be used to determine which claims to be selectively disclosed. See [SDJWTClaimsStrategy] for more details.
/// * `sd_strategy` - The strategy to be used to determine which claims to be selectively disclosed. See [ClaimsForSelectiveDisclosureStrategy] for more details.
/// * `holder_key` - The key used to sign the SD-JWT. If not provided, no key binding is added to the SD-JWT.
/// * `add_decoy_claims` - If true, decoy claims are added to the SD-JWT.
/// * `serialization_format` - The serialization format to be used for the SD-JWT. Only "compact" and "json" formats are supported.
/// * `serialization_format` - The serialization format to be used for the SD-JWT, see [SDJWTSerializationFormat].
///
/// # Returns
/// The issued SD-JWT as a string in the requested serialization format.
pub fn issue_sd_jwt(
&mut self,
user_claims: Value,
mut sd_strategy: SDJWTClaimsStrategy,
mut sd_strategy: ClaimsForSelectiveDisclosureStrategy,
holder_key: Option<Jwk>,
add_decoy_claims: bool,
serialization_format: String,
serialization_format: SDJWTSerializationFormat,
// extra_header_parameters: Option<HashMap<String, String>>,
) -> Result<String> {
let inner = SDJWTCommon {
Expand All @@ -189,7 +188,7 @@ impl SDJWTIssuer {
fn assemble_sd_jwt_payload(
&mut self,
mut user_claims: Value,
sd_strategy: SDJWTClaimsStrategy,
sd_strategy: ClaimsForSelectiveDisclosureStrategy,
) -> Result<()> {
let claims_obj_ref = user_claims
.as_object_mut()
Expand Down Expand Up @@ -221,18 +220,18 @@ impl SDJWTIssuer {
Ok(())
}

fn create_sd_claims(&mut self, user_claims: &Value, sd_strategy: SDJWTClaimsStrategy) -> Value {
fn create_sd_claims(&mut self, user_claims: &Value, sd_strategy: ClaimsForSelectiveDisclosureStrategy) -> Value {
match user_claims {
Value::Array(list) => self.create_sd_claims_list(list, sd_strategy),
Value::Object(object) => self.create_sd_claims_object(object, sd_strategy),
_ => user_claims.to_owned(),
}
}

fn create_sd_claims_list(&mut self, list: &[Value], sd_strategy: SDJWTClaimsStrategy) -> Value {
fn create_sd_claims_list(&mut self, list: &[Value], sd_strategy: ClaimsForSelectiveDisclosureStrategy) -> Value {
let mut claims = Vec::new();
for (idx, object) in list.iter().enumerate() {
let key = idx.to_string();
let key = format!("[{idx}]");
let strategy_for_child = sd_strategy.next_level(&key);
let subtree = self.create_sd_claims(object, strategy_for_child);

Expand All @@ -250,7 +249,7 @@ impl SDJWTIssuer {
fn create_sd_claims_object(
&mut self,
user_claims: &SJMap<String, Value>,
sd_strategy: SDJWTClaimsStrategy,
sd_strategy: ClaimsForSelectiveDisclosureStrategy,
) -> Value {
let mut claims = SJMap::new();
let mut sd_claims = Vec::new();
Expand Down Expand Up @@ -308,7 +307,7 @@ impl SDJWTIssuer {
}

fn create_combined(&mut self) -> Result<()> {
if self.inner.serialization_format == "compact" {
if self.inner.serialization_format == SDJWTSerializationFormat::Compact {
let mut disclosures: VecDeque<String> = self
.all_disclosures
.iter()
Expand All @@ -323,7 +322,7 @@ impl SDJWTIssuer {
disclosures.join(COMBINED_SERIALIZATION_FORMAT_SEPARATOR),
COMBINED_SERIALIZATION_FORMAT_SEPARATOR,
);
} else if self.inner.serialization_format == "json" {
} else if self.inner.serialization_format == SDJWTSerializationFormat::JSON {
let jwt: Vec<&str> = self.signed_sd_jwt.split('.').collect();
if jwt.len() != 3 {
return Err(Error::InvalidInput(format!(
Expand All @@ -346,7 +345,7 @@ impl SDJWTIssuer {
.map_err(|e| Error::DeserializationError(e.to_string()))?;
} else {
return Err(Error::InvalidInput(
format!("Unknown serialization format {}, only \"compact\" or \"json\" formats are supported", self.inner.serialization_format)
format!("Unknown serialization format {}, only \"Compact\" or \"JSON\" formats are supported", self.inner.serialization_format)
));
}

Expand All @@ -365,8 +364,8 @@ mod tests {
use log::trace;
use serde_json::json;

use crate::issuer::SDJWTClaimsStrategy;
use crate::SDJWTIssuer;
use crate::issuer::ClaimsForSelectiveDisclosureStrategy;
use crate::{SDJWTIssuer, SDJWTSerializationFormat};

const PRIVATE_ISSUER_PEM: &str = "-----BEGIN PRIVATE KEY-----\nMIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgUr2bNKuBPOrAaxsR\nnbSH6hIhmNTxSGXshDSUD1a1y7ihRANCAARvbx3gzBkyPDz7TQIbjF+ef1IsxUwz\nX1KWpmlVv+421F7+c1sLqGk4HUuoVeN8iOoAcE547pJhUEJyf5Asc6pP\n-----END PRIVATE KEY-----\n";

Expand All @@ -388,12 +387,45 @@ mod tests {
let issuer_key = EncodingKey::from_ec_pem(private_issuer_bytes).unwrap();
let sd_jwt = SDJWTIssuer::new(issuer_key, None).issue_sd_jwt(
user_claims,
SDJWTClaimsStrategy::Full,
ClaimsForSelectiveDisclosureStrategy::AllLevels,
None,
false,
"compact".to_owned(),
SDJWTSerializationFormat::Compact,
)
.unwrap();
trace!("{:?}", sd_jwt)
}

#[test]
fn test_next_level_array() {
let strategy = ClaimsForSelectiveDisclosureStrategy::Custom(vec![
"name",
"addresses[1]",
"addresses[1].country",
"nationalities[0]",
]);

let next_strategy = strategy.next_level("addresses");
assert_eq!(&next_strategy, &ClaimsForSelectiveDisclosureStrategy::Custom(vec!["[1]", "[1].country"]));
let next_strategy = next_strategy.next_level("[1]");
assert_eq!(&next_strategy, &ClaimsForSelectiveDisclosureStrategy::Custom(vec!["country"]));
}

#[test]
fn test_next_level_object() {
let strategy = ClaimsForSelectiveDisclosureStrategy::Custom(vec![
"address.street_address",
"address.locality",
"address.region",
"address.country",
]);

let next_strategy = strategy.next_level("address");
assert_eq!(&next_strategy, &ClaimsForSelectiveDisclosureStrategy::Custom(vec![
"street_address",
"locality",
"region",
"country"
]));
}
}
Loading

0 comments on commit ed88931

Please sign in to comment.