Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

removed the 'leak' function call, updated enums #21

Merged
merged 7 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Proposals about API improvements are highly appreciated.
```rust
fn demo() {
let mut issuer = SDJWTIssuer::new(issuer_key, None);
let sd_jwt = issuer.issue_sd_jwt(claims, SDJWTClaimsStrategy::Full, holder_key, add_decoy, "compact".to_owned()).unwrap();
let sd_jwt = issuer.issue_sd_jwt(claims, ClaimsForSelectiveDisclosureStrategy::AllLevels, holder_key, add_decoy, "compact".to_owned()).unwrap();

let mut holder = SDJWTHolder::new(sd_jwt, "compact".to_owned()).unwrap();
let presentation = holder.create_presentation(claims_to_disclosure, None, None, None, None).unwrap();
Expand Down
40 changes: 16 additions & 24 deletions src/holder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{error, SDJWTJson};
use crate::{error, SDJWTJson, SDJWTSerializationFormat};
use error::{Error, Result};
use jsonwebtoken::{Algorithm, EncodingKey, Header};
use serde_json::{Map, Value};
Expand Down Expand Up @@ -29,7 +29,7 @@ impl SDJWTHolder {
///
/// # Arguments
/// * `sd_jwt_with_disclosures` - SD JWT with disclosures in the format specified by `serialization_format`
/// * `serialization_format` - Serialization format of the SD JWT. Supported values are `compact` and `json`
/// * `serialization_format` - Serialization format of the SD JWT, see [SDJWTSerializationFormat].
///
/// # Returns
/// * `SDJWTHolder` - Instance of SDJWTHolder
Expand All @@ -38,18 +38,10 @@ impl SDJWTHolder {
/// * `InvalidInput` - If the serialization format is not supported
/// * `InvalidState` - If the SD JWT data is not valid
/// * `DeserializationError` - If the SD JWT serialization is not valid
pub fn new(sd_jwt_with_disclosures: String, serialization_format: String) -> Result<Self> {
let serialization_format = serialization_format.to_lowercase();
if serialization_format != "compact" && serialization_format != "json" {
return Err(Error::InvalidInput(format!(
"Serialization format \"{}\" is not supported",
serialization_format
)));
}

pub fn new(sd_jwt_with_disclosures: String, serialization_format: SDJWTSerializationFormat) -> Result<Self> {
let mut holder = SDJWTHolder {
sd_jwt_engine: SDJWTCommon {
serialization_format: serialization_format.clone(),
serialization_format,
..Default::default()
},
hs_disclosures: Vec::new(),
Expand Down Expand Up @@ -119,7 +111,7 @@ impl SDJWTHolder {
}
}

let sd_jwt_presentation = if self.sd_jwt_engine.serialization_format == "compact" {
let sd_jwt_presentation = if self.sd_jwt_engine.serialization_format == SDJWTSerializationFormat::Compact {
let mut combined: Vec<&str> = Vec::with_capacity(self.hs_disclosures.len() + 2);
combined.push(&self.serialized_sd_jwt);
combined.extend(self.hs_disclosures.iter().map(|s| s.as_str()));
Expand Down Expand Up @@ -349,8 +341,8 @@ impl SDJWTHolder {

#[cfg(test)]
mod tests {
use crate::issuer::SDJWTClaimsStrategy;
use crate::{SDJWTHolder, SDJWTIssuer, COMBINED_SERIALIZATION_FORMAT_SEPARATOR};
use crate::issuer::ClaimsForSelectiveDisclosureStrategy;
use crate::{SDJWTHolder, SDJWTIssuer, COMBINED_SERIALIZATION_FORMAT_SEPARATOR, SDJWTSerializationFormat};
use jsonwebtoken::EncodingKey;
use serde_json::{json, Map, Value};
use std::collections::HashSet;
Expand All @@ -375,15 +367,15 @@ mod tests {
let issuer_key = EncodingKey::from_ec_pem(private_issuer_bytes).unwrap();
let sd_jwt = SDJWTIssuer::new(issuer_key, None).issue_sd_jwt(
user_claims.clone(),
SDJWTClaimsStrategy::Full,
ClaimsForSelectiveDisclosureStrategy::AllLevels,
None,
false,
"compact".to_owned(),
SDJWTSerializationFormat::Compact,
)
.unwrap();
let presentation = SDJWTHolder::new(
sd_jwt.clone(),
"compact".to_ascii_lowercase(),
SDJWTSerializationFormat::Compact,
)
.unwrap()
.create_presentation(
Expand Down Expand Up @@ -415,16 +407,16 @@ mod tests {

let sd_jwt = SDJWTIssuer::new(issuer_key, None).issue_sd_jwt(
user_claims.clone(),
SDJWTClaimsStrategy::Full,
ClaimsForSelectiveDisclosureStrategy::AllLevels,
None,
false,
"compact".to_owned(),
SDJWTSerializationFormat::Compact,
)
.unwrap();
let issued = sd_jwt.clone();
user_claims["address"] = Value::Object(Map::new());
let presentation =
SDJWTHolder::new(sd_jwt, "compact".to_ascii_lowercase())
SDJWTHolder::new(sd_jwt, SDJWTSerializationFormat::Compact)
.unwrap()
.create_presentation(
user_claims.as_object().unwrap().clone(),
Expand Down Expand Up @@ -472,7 +464,7 @@ mod tests {
]
}
);
let strategy = SDJWTClaimsStrategy::Partial(vec![
let strategy = ClaimsForSelectiveDisclosureStrategy::Custom(vec![
"$.name",
"$.addresses[1]",
"$.addresses[1].country",
Expand All @@ -486,7 +478,7 @@ mod tests {
strategy,
None,
false,
"compact".to_owned(),
SDJWTSerializationFormat::Compact,
)
.unwrap();
// Choose what to reveal
Expand All @@ -496,7 +488,7 @@ mod tests {
let issued = sd_jwt.clone();
println!("{}", issued);
let presentation =
SDJWTHolder::new(sd_jwt, "compact".to_ascii_lowercase())
SDJWTHolder::new(sd_jwt, SDJWTSerializationFormat::Compact)
.unwrap()
.create_presentation(
user_claims.as_object().unwrap().clone(),
Expand Down
128 changes: 80 additions & 48 deletions src/issuer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::utils::{base64_hash, generate_salt};
use crate::{
SDJWTCommon, CNF_KEY, COMBINED_SERIALIZATION_FORMAT_SEPARATOR, DEFAULT_DIGEST_ALG,
DEFAULT_SIGNING_ALG, DIGEST_ALG_KEY, JWK_KEY, SD_DIGESTS_KEY, SD_LIST_PREFIX,
SDJWTSerializationFormat,
};

pub struct SDJWTIssuer {
Expand All @@ -36,28 +37,30 @@ pub struct SDJWTIssuer {
pub serialized_sd_jwt: String,
}

/// SDJWTClaimsStrategy is used to determine which claims can be selectively disclosed later by the holder.
pub enum SDJWTClaimsStrategy<'a> {
/// No claims can be selectively disclosed. Full disclosure.
No,
/// Top-level claims can be selectively disclosed, nested objects are fully disclosed.
Flat,
/// ClaimsForSelectiveDisclosureStrategy is used to determine which claims can be selectively disclosed later by the holder.
#[derive(PartialEq, Debug)]
pub enum ClaimsForSelectiveDisclosureStrategy<'a> {
/// No claims can be selectively disclosed, so all claims are always disclosed in presentations generated by the holder.
NoSDClaims,
/// Top-level claims can be selectively disclosed, nested objects are fully disclosed, if a parent claim is disclosed.
TopLevel,
/// All claims can be selectively disclosed (recursively including nested objects).
Full,
/// Claims can be selectively disclosed based on the provided JSONPaths. Other claims are fully disclosed.
AllLevels,
/// Claims can be selectively disclosed based on the provided JSONPaths.
/// Other claims are always disclosed in presentation generated by the holder.
/// # Examples
/// ```
/// use sd_jwt_rs::issuer::SDJWTClaimsStrategy;
/// use sd_jwt_rs::issuer::ClaimsForSelectiveDisclosureStrategy;
///
/// let strategy = SDJWTClaimsStrategy::Partial(vec!["$.address", "$.address.street_address"]);
/// let strategy = ClaimsForSelectiveDisclosureStrategy::Custom(vec!["$.address", "$.address.street_address"]);
/// ```
Partial(Vec<&'a str>),
Custom(Vec<&'a str>),
}

impl<'a> SDJWTClaimsStrategy<'a> {
impl<'a> ClaimsForSelectiveDisclosureStrategy<'a> {
fn finalize_input(&mut self) -> Result<()> {
match self {
SDJWTClaimsStrategy::Partial(keys) => {
ClaimsForSelectiveDisclosureStrategy::Custom(keys) => {
for key in keys.iter_mut() {
if let Some(new_key) = key.strip_prefix("$.") {
*key = new_key;
Expand All @@ -73,37 +76,33 @@ impl<'a> SDJWTClaimsStrategy<'a> {

fn next_level(&self, key: &str) -> Self {
match self {
Self::No => Self::No,
Self::Flat => Self::No,
Self::Full => Self::Full,
Self::Partial(sd_keys) => {
Self::NoSDClaims => Self::NoSDClaims,
Self::TopLevel => Self::NoSDClaims,
Self::AllLevels => Self::AllLevels,
Self::Custom(sd_keys) => {
let next_sd_keys = sd_keys
.iter()
.filter_map(|str| {
str.strip_prefix(key).as_mut().map(|claim| {
if let Some(next_claim) = claim.strip_prefix('.') {
next_claim
} else {
// FIXME Replace to non-leackable impl
// Removes "[", "]" symbols form "index" and returns "next_claim" as "index.remained_claims.."
// For example: [0].street -> 0.street
*claim = claim.replace(['[', ']'], "").leak();
claim
str.strip_prefix(key).and_then(|str|
match str.chars().next() {
Some('.') => Some(&str[1..]), // next token
Some('[') => Some(str), // array index
_ => None,
}
})
)
})
.collect();
Self::Partial(next_sd_keys)
Self::Custom(next_sd_keys)
}
}
}

fn sd_for_key(&self, key: &str) -> bool {
match self {
Self::No => false,
Self::Flat => true,
Self::Full => true,
Self::Partial(sd_keys) => sd_keys.contains(&key),
Self::NoSDClaims => false,
Self::TopLevel => true,
Self::AllLevels => true,
Self::Custom(sd_keys) => sd_keys.contains(&key),
}
}
}
Expand Down Expand Up @@ -149,20 +148,20 @@ impl SDJWTIssuer {
///
/// # Arguments
/// * `user_claims` - The claims to be included in the SD-JWT.
/// * `sd_strategy` - The strategy to be used to determine which claims to be selectively disclosed. See [SDJWTClaimsStrategy] for more details.
/// * `sd_strategy` - The strategy to be used to determine which claims to be selectively disclosed. See [ClaimsForSelectiveDisclosureStrategy] for more details.
/// * `holder_key` - The key used to sign the SD-JWT. If not provided, no key binding is added to the SD-JWT.
/// * `add_decoy_claims` - If true, decoy claims are added to the SD-JWT.
/// * `serialization_format` - The serialization format to be used for the SD-JWT. Only "compact" and "json" formats are supported.
/// * `serialization_format` - The serialization format to be used for the SD-JWT, see [SDJWTSerializationFormat].
///
/// # Returns
/// The issued SD-JWT as a string in the requested serialization format.
pub fn issue_sd_jwt(
&mut self,
user_claims: Value,
mut sd_strategy: SDJWTClaimsStrategy,
mut sd_strategy: ClaimsForSelectiveDisclosureStrategy,
holder_key: Option<Jwk>,
add_decoy_claims: bool,
serialization_format: String,
serialization_format: SDJWTSerializationFormat,
// extra_header_parameters: Option<HashMap<String, String>>,
) -> Result<String> {
let inner = SDJWTCommon {
Expand All @@ -189,7 +188,7 @@ impl SDJWTIssuer {
fn assemble_sd_jwt_payload(
&mut self,
mut user_claims: Value,
sd_strategy: SDJWTClaimsStrategy,
sd_strategy: ClaimsForSelectiveDisclosureStrategy,
) -> Result<()> {
let claims_obj_ref = user_claims
.as_object_mut()
Expand Down Expand Up @@ -221,18 +220,18 @@ impl SDJWTIssuer {
Ok(())
}

fn create_sd_claims(&mut self, user_claims: &Value, sd_strategy: SDJWTClaimsStrategy) -> Value {
fn create_sd_claims(&mut self, user_claims: &Value, sd_strategy: ClaimsForSelectiveDisclosureStrategy) -> Value {
match user_claims {
Value::Array(list) => self.create_sd_claims_list(list, sd_strategy),
Value::Object(object) => self.create_sd_claims_object(object, sd_strategy),
_ => user_claims.to_owned(),
}
}

fn create_sd_claims_list(&mut self, list: &[Value], sd_strategy: SDJWTClaimsStrategy) -> Value {
fn create_sd_claims_list(&mut self, list: &[Value], sd_strategy: ClaimsForSelectiveDisclosureStrategy) -> Value {
let mut claims = Vec::new();
for (idx, object) in list.iter().enumerate() {
let key = idx.to_string();
let key = format!("[{idx}]");
let strategy_for_child = sd_strategy.next_level(&key);
let subtree = self.create_sd_claims(object, strategy_for_child);

Expand All @@ -250,7 +249,7 @@ impl SDJWTIssuer {
fn create_sd_claims_object(
&mut self,
user_claims: &SJMap<String, Value>,
sd_strategy: SDJWTClaimsStrategy,
sd_strategy: ClaimsForSelectiveDisclosureStrategy,
) -> Value {
let mut claims = SJMap::new();
let mut sd_claims = Vec::new();
Expand Down Expand Up @@ -308,7 +307,7 @@ impl SDJWTIssuer {
}

fn create_combined(&mut self) -> Result<()> {
if self.inner.serialization_format == "compact" {
if self.inner.serialization_format == SDJWTSerializationFormat::Compact {
let mut disclosures: VecDeque<String> = self
.all_disclosures
.iter()
Expand All @@ -323,7 +322,7 @@ impl SDJWTIssuer {
disclosures.join(COMBINED_SERIALIZATION_FORMAT_SEPARATOR),
COMBINED_SERIALIZATION_FORMAT_SEPARATOR,
);
} else if self.inner.serialization_format == "json" {
} else if self.inner.serialization_format == SDJWTSerializationFormat::JSON {
let jwt: Vec<&str> = self.signed_sd_jwt.split('.').collect();
if jwt.len() != 3 {
return Err(Error::InvalidInput(format!(
Expand All @@ -346,7 +345,7 @@ impl SDJWTIssuer {
.map_err(|e| Error::DeserializationError(e.to_string()))?;
} else {
return Err(Error::InvalidInput(
format!("Unknown serialization format {}, only \"compact\" or \"json\" formats are supported", self.inner.serialization_format)
format!("Unknown serialization format {}, only \"Compact\" or \"JSON\" formats are supported", self.inner.serialization_format)
));
}

Expand All @@ -365,8 +364,8 @@ mod tests {
use log::trace;
use serde_json::json;

use crate::issuer::SDJWTClaimsStrategy;
use crate::SDJWTIssuer;
use crate::issuer::ClaimsForSelectiveDisclosureStrategy;
use crate::{SDJWTIssuer, SDJWTSerializationFormat};

const PRIVATE_ISSUER_PEM: &str = "-----BEGIN PRIVATE KEY-----\nMIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgUr2bNKuBPOrAaxsR\nnbSH6hIhmNTxSGXshDSUD1a1y7ihRANCAARvbx3gzBkyPDz7TQIbjF+ef1IsxUwz\nX1KWpmlVv+421F7+c1sLqGk4HUuoVeN8iOoAcE547pJhUEJyf5Asc6pP\n-----END PRIVATE KEY-----\n";

Expand All @@ -388,12 +387,45 @@ mod tests {
let issuer_key = EncodingKey::from_ec_pem(private_issuer_bytes).unwrap();
let sd_jwt = SDJWTIssuer::new(issuer_key, None).issue_sd_jwt(
user_claims,
SDJWTClaimsStrategy::Full,
ClaimsForSelectiveDisclosureStrategy::AllLevels,
None,
false,
"compact".to_owned(),
SDJWTSerializationFormat::Compact,
)
.unwrap();
trace!("{:?}", sd_jwt)
}

#[test]
fn test_next_level_array() {
let strategy = ClaimsForSelectiveDisclosureStrategy::Custom(vec![
"name",
"addresses[1]",
"addresses[1].country",
"nationalities[0]",
]);

let next_strategy = strategy.next_level("addresses");
assert_eq!(&next_strategy, &ClaimsForSelectiveDisclosureStrategy::Custom(vec!["[1]", "[1].country"]));
let next_strategy = next_strategy.next_level("[1]");
assert_eq!(&next_strategy, &ClaimsForSelectiveDisclosureStrategy::Custom(vec!["country"]));
}

#[test]
fn test_next_level_object() {
let strategy = ClaimsForSelectiveDisclosureStrategy::Custom(vec![
"address.street_address",
"address.locality",
"address.region",
"address.country",
]);

let next_strategy = strategy.next_level("address");
assert_eq!(&next_strategy, &ClaimsForSelectiveDisclosureStrategy::Custom(vec![
"street_address",
"locality",
"region",
"country"
]));
}
}
Loading