From 7e04c04ae51166e67449c57252cc091e27727ea9 Mon Sep 17 00:00:00 2001 From: Abdulbois Date: Wed, 13 Dec 2023 15:00:13 +0500 Subject: [PATCH] feat: Add sd-jwt-generate crate Signed-off-by: Abdulbois --- generate/Cargo.toml | 14 ++ generate/src/error.rs | 296 ++++++++++++++++++++++++++++ generate/src/main.rs | 171 ++++++++++++++++ generate/src/types/cli.rs | 20 ++ generate/src/types/mod.rs | 3 + generate/src/types/settings.rs | 90 +++++++++ generate/src/types/specification.rs | 144 ++++++++++++++ generate/src/utils/generate.rs | 66 +++++++ generate/src/utils/mod.rs | 1 + 9 files changed, 805 insertions(+) create mode 100644 generate/Cargo.toml create mode 100644 generate/src/error.rs create mode 100644 generate/src/main.rs create mode 100644 generate/src/types/cli.rs create mode 100644 generate/src/types/mod.rs create mode 100644 generate/src/types/settings.rs create mode 100644 generate/src/types/specification.rs create mode 100644 generate/src/utils/generate.rs create mode 100644 generate/src/utils/mod.rs diff --git a/generate/Cargo.toml b/generate/Cargo.toml new file mode 100644 index 0000000..f76864e --- /dev/null +++ b/generate/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "sd-jwt-generate" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +clap = { version = "4.4.10", features = ["derive"] } +serde = { version = "1.0.193", features = ["derive"] } +serde_yaml = "0.9.27" +serde_json = "1.0.108" +jsonwebtoken = "9.1" +sd-jwt-rs = {path = "./.."} \ No newline at end of file diff --git a/generate/src/error.rs b/generate/src/error.rs new file mode 100644 index 0000000..9905d78 --- /dev/null +++ b/generate/src/error.rs @@ -0,0 +1,296 @@ +#![allow(unused)] + +use std::error::Error as StdError; +use std::fmt::{self, Display, Formatter}; +use std::result::Result as StdResult; +use serde_json; +use serde_yaml; + +pub type Result = std::result::Result; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ErrorKind { + Input, + IOError, +} + +impl ErrorKind { + #[must_use] + pub const fn as_str(&self) -> &'static str { + match self { + Self::Input => "Input error", + Self::IOError => "IO error" + } + } +} + +impl Display for ErrorKind { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +/// The standard crate error type +#[derive(Debug)] +pub struct Error { + kind: ErrorKind, + pub cause: Option>, + pub message: Option, + // backtrace (when supported) +} + +impl Error { + pub fn from_msg>(kind: ErrorKind, msg: T) -> Self { + Self { + kind, + cause: None, + message: Some(msg.into()), + } + } + + pub fn from_opt_msg>(kind: ErrorKind, msg: Option) -> Self { + Self { + kind, + cause: None, + message: msg.map(Into::into), + } + } + + #[must_use] + #[inline] + pub const fn kind(&self) -> ErrorKind { + self.kind + } + + #[must_use] + pub fn with_cause>>(mut self, err: T) -> Self { + self.cause = Some(err.into()); + self + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match (self.kind, &self.message) { + (ErrorKind::Input, None) => write!(f, "{:?}", self.kind), + (ErrorKind::Input, Some(msg)) => f.write_str(msg), + (kind, None) => write!(f, "{kind}"), + (kind, Some(msg)) => write!(f, "{kind}: {msg}"), + }?; + if let Some(ref source) = self.cause { + write!(f, " [{source}]")?; + } + Ok(()) + } +} + +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.cause + .as_ref() + .map(|err| unsafe { std::mem::transmute(&**err) }) + } +} + +impl PartialEq for Error { + fn eq(&self, other: &Self) -> bool { + self.kind == other.kind && self.message == other.message + } +} + +impl From for Error { + fn from(kind: ErrorKind) -> Self { + Self { + kind, + cause: None, + message: None, + } + } +} + +impl From for Error { + fn from(err: std::io::Error) -> Self { + Self::from(ErrorKind::IOError).with_cause(err) + } +} + +impl From for Error { + fn from(err: serde_json::Error) -> Self { + Self::from(ErrorKind::Input).with_cause(err) + } +} + +impl From for Error { + fn from(err: serde_yaml::Error) -> Self { + Self::from(ErrorKind::Input).with_cause(err) + } +} + +impl From<(ErrorKind, M)> for Error +where + M: fmt::Display + Send + Sync + 'static, +{ + fn from((kind, msg): (ErrorKind, M)) -> Self { + Self::from_msg(kind, msg.to_string()) + } +} + +macro_rules! err_msg { + () => { + $crate::error::Error::from($crate::error::ErrorKind::Input) + }; + ($kind:ident) => { + $crate::error::Error::from($crate::error::ErrorKind::$kind) + }; + ($kind:ident, $($args:tt)+) => { + $crate::error::Error::from_msg($crate::error::ErrorKind::$kind, format!($($args)+)) + }; + ($($args:tt)+) => { + $crate::error::Error::from_msg($crate::error::ErrorKind::Input, format!($($args)+)) + }; +} + +macro_rules! err_map { + ($($params:tt)*) => { + |err| err_msg!($($params)*).with_cause(err) + }; +} + +pub trait ResultExt { + fn map_err_string(self) -> StdResult; + fn map_input_err(self, mapfn: F) -> Result + where + F: FnOnce() -> M, + M: fmt::Display + Send + Sync + 'static; + fn with_err_msg(self, kind: ErrorKind, msg: M) -> Result + where + M: fmt::Display + Send + Sync + 'static; + fn with_input_err(self, msg: M) -> Result + where + M: fmt::Display + Send + Sync + 'static; +} + +impl ResultExt for StdResult +where + E: std::error::Error + Send + Sync + 'static, +{ + fn map_err_string(self) -> StdResult { + self.map_err(|err| err.to_string()) + } + + fn map_input_err(self, mapfn: F) -> Result + where + F: FnOnce() -> M, + M: fmt::Display + Send + Sync + 'static, + { + self.map_err(|err| Error::from_msg(ErrorKind::Input, mapfn().to_string()).with_cause(err)) + } + + fn with_err_msg(self, kind: ErrorKind, msg: M) -> Result + where + M: fmt::Display + Send + Sync + 'static, + { + self.map_err(|err| Error::from_msg(kind, msg.to_string()).with_cause(err)) + } + + #[inline] + fn with_input_err(self, msg: M) -> Result + where + M: fmt::Display + Send + Sync + 'static, + { + self.map_err(|err| Error::from_msg(ErrorKind::Input, msg.to_string()).with_cause(err)) + } +} + +type DynError = Box; + +macro_rules! define_error { + ($name:tt, $short:expr, $doc:tt) => { + #[derive(Debug, Error)] + #[doc=$doc] + pub struct $name { + pub context: Option, + pub source: Option, + } + + impl $name { + pub fn from_msg>(msg: T) -> Self { + Self::from(msg.into()) + } + + pub fn from_err(err: E) -> Self + where + E: StdError + Send + Sync + 'static, + { + Self { + context: None, + source: Some(Box::new(err) as DynError), + } + } + + pub fn from_msg_err(msg: M, err: E) -> Self + where + M: Into, + E: StdError + Send + Sync + 'static, + { + Self { + context: Some(msg.into()), + source: Some(Box::new(err) as DynError), + } + } + } + + impl From<&str> for $name { + fn from(context: &str) -> Self { + Self { + context: Some(context.to_owned()), + source: None, + } + } + } + + impl From for $name { + fn from(context: String) -> Self { + Self { + context: Some(context), + source: None, + } + } + } + + impl From> for $name { + fn from(context: Option) -> Self { + Self { + context, + source: None, + } + } + } + + impl From<(M, E)> for $name + where + M: Into, + E: StdError + Send + Sync + 'static, + { + fn from((context, err): (M, E)) -> Self { + Self::from_msg_err(context, err) + } + } + + impl From<$name> for String { + fn from(s: $name) -> Self { + s.to_string() + } + } + + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, $short)?; + match self.context { + Some(ref context) => write!(f, ": {}", context), + None => Ok(()), + } + } + } + }; +} \ No newline at end of file diff --git a/generate/src/main.rs b/generate/src/main.rs new file mode 100644 index 0000000..5e0674f --- /dev/null +++ b/generate/src/main.rs @@ -0,0 +1,171 @@ +mod error; +mod types; +mod utils; + +use crate::error::{Error, ErrorKind, Result}; +use clap::Parser; +use jsonwebtoken::EncodingKey; +use sd_jwt_rs::issuer::{SDJWTClaimsStrategy, SDJWTIssuer}; +use sd_jwt_rs::SALTS; +use serde_json::Value; +use std::collections::HashMap; +use std::path::PathBuf; +use types::cli::{Cli, GenerateType}; +use types::settings::Settings; +use types::specification::Specification; + +const ISSUER_KEY_PEM_FILE_NAME: &str = "issuer_key.pem"; +// const HOLDER_KEY_PEM_FILE_NAME: &str = "holder_key.pem"; +const SERIALIZATION_FORMAT: &str = "compact"; +const SETTINGS_FILE_NAME: &str = "settings.yml"; +const SPECIFICATION_FILE_NAME: &str = "specification.yml"; +const SALTS_FILE_NAME: &str = "claims_vs_salts.json"; +const SD_JWT_PAYLOAD_FILE_NAME: &str = "sd_jwt_payload.json"; + +fn main() { + let args = Cli::parse(); + + println!("type_: {:?}, paths: {:?}", args.type_.clone(), args.paths); + + let basedir = std::env::current_dir().expect("Unable to get current directory"); + + let settings = get_settings(&basedir.join(SETTINGS_FILE_NAME)); + + let spec_directories = get_specification_paths(&args, basedir).unwrap(); + + for mut directory in spec_directories { + println!("Generating data for '{:?}'", directory); + let specs = Specification::from(&directory); + // Remove specification.yaml from path + directory.pop(); + + generate_and_check(&directory, &settings, specs, args.type_.clone()).unwrap(); + } +} + +fn generate_and_check( + directory: &PathBuf, + _: &Settings, + specs: Specification, + _: GenerateType, +) -> Result<()> { + // let seed = settings.random_seed.unwrap_or(0); + + // Get keys from .pem files + let issuer_key = get_key(&directory.join(ISSUER_KEY_PEM_FILE_NAME)); + // let holder_key = get_key(key_path.join(HOLDER_KEY_PEM_FILE_NAME)); + + let user_claims = specs.user_claims.claims_to_json_value()?; + let decoy = specs.add_decoy_claims.unwrap_or(false); + let sd_claims_jsonpaths = specs.user_claims.sd_claims_to_jsonpath()?; + + let strategy = + SDJWTClaimsStrategy::Partial(sd_claims_jsonpaths.iter().map(String::as_str).collect()); + + let issuer = SDJWTIssuer::issue_sd_jwt( + user_claims, + strategy, + issuer_key, + None, + None, + decoy, + SERIALIZATION_FORMAT.to_string(), + ); + println!("Issued SD-JWT \n {:#?}", issuer.sd_jwt_payload); + + compare_jwt_payloads( + &directory.join(SD_JWT_PAYLOAD_FILE_NAME), + &issuer.sd_jwt_payload, + ) + + // let mut holder = SDJWTHolder::new( + // issuer.serialized_sd_jwt.clone(), + // SERIALIZATION_FORMAT.to_string(), + // ); + // holder.create_presentation(Some(vec!["address".to_string()]), None, None, None, None); + // println!("Created presentation \n {:?}", holder.sd_jwt_presentation) +} + +fn compare_jwt_payloads(path: &PathBuf, compare: &serde_json::Map) -> Result<()> { + let contents = std::fs::read_to_string(path)?; + + let json_value: serde_json::Map = serde_json::from_str(&contents) + .expect(&format!("Failed to parse to serde_json::Value {:?}", path)); + + if json_value.eq(compare) { + println!("Issued JWT payload is the same as payload of {:?}", path); + } else { + eprintln!( + "Issued JWT payload is NOT the same as payload of {:?}", + path + ); + + println!("Issued SD-JWT \n {:#?}", compare); + println!("Loaded SD-JWT \n {:#?}", json_value); + } + + Ok(()) +} + +fn get_key(path: &PathBuf) -> EncodingKey { + let key = std::fs::read(path).expect("Failed to read file"); + + EncodingKey::from_ec_pem(&key).expect("Unable to create EncodingKey") +} + +fn get_settings(path: &PathBuf) -> Settings { + println!("settings.yaml - {:?}", path); + + let settings = Settings::from(path); + println!("{:#?}", settings); + + settings +} + +fn get_specification_paths(args: &Cli, basedir: PathBuf) -> Result> { + let glob: Vec; + if args.paths.is_empty() { + glob = basedir + .read_dir()? + .filter_map(|entry| { + if let Ok(entry) = entry { + let path = entry.path(); + if path.is_dir() && path.join(SPECIFICATION_FILE_NAME).exists() { + // load_salts(&path).map_err(|err| Error::from_msg(ErrorKind::IOError, err.to_string()))?; + load_salts(&path).unwrap(); + return Some(path.join(SPECIFICATION_FILE_NAME)); + } + } + None + }) + .collect(); + } else { + glob = args + .paths + .iter() + .map(|d| { + // load_salts(&path).map_err(|err| Error::from_msg(ErrorKind::IOError, err.to_string()))?; + load_salts(&d).unwrap(); + basedir.join(d).join(SPECIFICATION_FILE_NAME) + }) + .collect(); + } + + println!("specification.yaml files - {:?}", glob); + + Ok(glob) +} + +fn load_salts(path: &PathBuf) -> Result<()> { + let salts_path = path.join(SALTS_FILE_NAME); + let json_data = std::fs::read_to_string(salts_path) + .map_err(|e| Error::from_msg(ErrorKind::IOError, e.to_string()))?; + let salts: HashMap = serde_json::from_str(&json_data)?; + + { + let mut map = SALTS.lock().unwrap(); + map.extend(salts.into_iter()); + } + + Ok(()) +} diff --git a/generate/src/types/cli.rs b/generate/src/types/cli.rs new file mode 100644 index 0000000..49b7fc3 --- /dev/null +++ b/generate/src/types/cli.rs @@ -0,0 +1,20 @@ +use clap::Parser; +use serde::Serialize; + +#[derive(Parser)] +pub struct Cli { + /// The type to generate + #[arg(short, value_enum, default_value_t = GenerateType::Example)] + pub type_: GenerateType, + /// The paths to the directories where specification.yaml file is located + #[arg(short, value_delimiter = ' ', num_args = 0.., require_equals = false)] + pub paths: Vec, +} + + +#[derive(clap::ValueEnum, Clone, Debug, Serialize)] +#[serde(rename_all = "kebab-case")] +pub enum GenerateType { + Example, + TestCase, +} \ No newline at end of file diff --git a/generate/src/types/mod.rs b/generate/src/types/mod.rs new file mode 100644 index 0000000..8bad8a4 --- /dev/null +++ b/generate/src/types/mod.rs @@ -0,0 +1,3 @@ +pub mod settings; +pub mod specification; +pub mod cli; \ No newline at end of file diff --git a/generate/src/types/settings.rs b/generate/src/types/settings.rs new file mode 100644 index 0000000..08ce014 --- /dev/null +++ b/generate/src/types/settings.rs @@ -0,0 +1,90 @@ +use std::path::PathBuf; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, PartialEq, Debug)] +pub struct KeySettings { + pub key_size: i32, + pub kty: String, + pub issuer_key: Key, + pub holder_key: Key, +} + +#[derive(Serialize, Deserialize, PartialEq, Debug)] +pub struct Key { + pub kty: String, + pub d: String, + pub crv: String, + pub x: String, + pub y: String, +} + +#[derive(Serialize, Deserialize, PartialEq, Debug)] +pub struct Identifiers { + pub issuer: String, + pub verifier: String, +} + +#[derive(Serialize, Deserialize, PartialEq, Debug)] +pub struct Settings { + pub identifiers: Identifiers, + pub key_settings: KeySettings, + pub key_binding_nonce: String, + pub expiry_seconds: Option, + pub random_seed: Option, + pub iat: Option, + pub exp: Option, +} + +impl From<&PathBuf> for Settings { + fn from(path: &PathBuf) -> Self { + let contents = std::fs::read_to_string(path) + .expect("Failed to read settings file"); + + let settings: Settings = serde_yaml::from_str(&contents) + .expect("Failed to parse YAML"); + + settings + } +} + +#[cfg(test)] +mod tests { + use crate::types::settings::Settings; + + #[test] + fn test_test_settings() { + let yaml_str = r#" + identifiers: + issuer: "https://example.com/issuer" + verifier: "https://example.com/verifier" + + key_settings: + key_size: 256 + kty: "EC" + issuer_key: + kty: "EC" + d: "Ur2bNKuBPOrAaxsRnbSH6hIhmNTxSGXshDSUD1a1y7g" + crv: "P-256" + x: "b28d4MwZMjw8-00CG4xfnn9SLMVMM19SlqZpVb_uNtQ" + y: "Xv5zWwuoaTgdS6hV43yI6gBwTnjukmFQQnJ_kCxzqk8" + holder_key: + kty: "EC" + d: "5K5SCos8zf9zRemGGUl6yfok-_NiiryNZsvANWMhF-I" + crv: "P-256" + x: "TCAER19Zvu3OHF4j4W4vfSVoHIP1ILilDls7vCeGemc" + y: "ZxjiWWbZMQGHVWKVQ4hbSIirsVfuecCE6t4jT9F2HZQ" + + key_binding_nonce: "1234567890" + + expiry_seconds: 86400000 + random_seed: 0 + iat: 1683000000 + exp: 1883000000 + "#; + + let settings: Settings = serde_yaml::from_str(yaml_str).unwrap(); + println!("{:#?}", settings); + assert_eq!(settings.identifiers.issuer, "https://example.com/issuer"); + } +} + diff --git a/generate/src/types/specification.rs b/generate/src/types/specification.rs new file mode 100644 index 0000000..a210a4c --- /dev/null +++ b/generate/src/types/specification.rs @@ -0,0 +1,144 @@ +use crate::utils::generate::generate_jsonpath_from_tagged_values; +use serde::{Deserialize, Serialize}; +use serde_yaml::Value; +use std::collections::HashMap; +use std::path::PathBuf; +use crate::error::Result; + +const SD_TAG: &str = "!sd"; + +#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Default)] +pub struct Specification { + pub user_claims: UserClaims, + pub holder_disclosed_claims: HashMap, + pub add_decoy_claims: Option, + pub key_binding: Option, +} + +impl From<&str> for Specification { + fn from(value: &str) -> Self { + serde_yaml::from_str(value).unwrap_or(Specification::default()) + } +} + +impl From<&PathBuf> for Specification { + fn from(path: &PathBuf) -> Self { + let contents = std::fs::read_to_string(path).expect("Failed to read specification file"); + + let spec: Specification = serde_yaml::from_str(&contents).expect("Failed to parse YAML"); + + spec + } +} + +#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Default)] +pub struct UserClaims(HashMap); + +impl UserClaims { + pub fn claims_to_json_value(&self) -> Result { + let value = serde_yaml::to_value(&self.0) + .expect("Failed to convert user-claims into serde_yaml::Value"); + let filtered_value = _remove_tags(&value); + let json_value: serde_json::Value = + serde_yaml::from_value(filtered_value).expect("Failed to convert serde_json::Value"); + + Ok(json_value) + } + + pub fn sd_claims_to_jsonpath(&self) -> Result> { + let mut path = "".to_string(); + let mut paths = Vec::new(); + let mut claims = serde_yaml::to_value(&self.0)?; + + let _ = generate_jsonpath_from_tagged_values(&mut claims, &mut path, &mut paths); + + Ok(paths) + } +} + +fn _validate(value: &Value) -> Result<()> { + match value { + Value::String(_) | Value::Bool(_) | Value::Number(_) => Ok(()), + Value::Tagged(tag) => { + if tag.tag.to_string() == SD_TAG { + _validate(&tag.value) + } else { + panic!( + "Unsupported tag {:?} in claim-name, only !sd tag is supported", + tag.tag + ); + } + } + Value::Sequence(list) => { + for v in list { + _validate(v)?; + } + + Ok(()) + } + Value::Mapping(map) => { + for (key, value) in map { + _validate(key)?; + _validate(value)?; + } + + Ok(()) + } + + _ => { + panic!("Unsupported type for claim-name, it can be only string or tagged"); + } + } +} + +fn _remove_tags(original: &Value) -> Value { + match original { + Value::Tagged(tag) => _remove_tags(&tag.value), + Value::Mapping(map) => { + let mut filtered_map = serde_yaml::Mapping::new(); + + for (key, value) in map.iter() { + match key { + Value::Tagged(tag) => { + let filtered_value = _remove_tags(value); + + filtered_map.insert(tag.value.clone(), filtered_value); + } + Value::Null => {} + _ => { + let filtered_value = _remove_tags(value); + filtered_map.insert(key.clone(), filtered_value); + } + } + } + + Value::Mapping(filtered_map) + } + Value::Sequence(seq) => { + let filtered_seq: Vec = seq.iter().map(|v| _remove_tags(v)).collect(); + + Value::Sequence(filtered_seq) + } + other => other.clone(), + } +} +#[cfg(test)] +mod tests { + use crate::types::specification::Specification; + + #[test] + fn test_specification() { + let yaml_str = r#" + user_claims: + sub: 6c5c0a49-b589-431d-bae7-219122a9ec2c + !sd address: + street_address: Schulstr. 12 + !sd street_address1: Schulstr. 12 + + holder_disclosed_claims: {} + "#; + + let spec = Specification::from(yaml_str); + println!("{:?}", spec.user_claims.claims_to_json_value().unwrap()) + } +} diff --git a/generate/src/utils/generate.rs b/generate/src/utils/generate.rs new file mode 100644 index 0000000..831f318 --- /dev/null +++ b/generate/src/utils/generate.rs @@ -0,0 +1,66 @@ +use serde_yaml::Value; +use crate::error::Result; + +#[allow(unused)] +pub fn generate_jsonpath_from_tagged_values( + yaml: &Value, + path: &mut String, + paths: &mut Vec, +) -> Result<()> { + match yaml { + Value::Mapping(map) => { + for (key, value) in map { + let len = path.len(); + + if path.is_empty() { + path.push_str("$."); + } + // Handle nested + match key { + Value::Tagged(tagged) => { + path.push_str(tagged.value.as_str().unwrap()); + + match value { + Value::Mapping(_) => { + path.push('.'); + generate_jsonpath_from_tagged_values(value, path, paths); + } + Value::Sequence(_) => { + generate_jsonpath_from_tagged_values(value, path, paths); + } + _ => {}, + } + + if path.ends_with('.') { + path.pop().unwrap(); + } + + paths.push(path.clone()); + } + Value::String(s) => { + path.push_str(s); + path.push('.'); + + generate_jsonpath_from_tagged_values(value, path, paths); + } + _ => {} + } + + path.truncate(len); + } + } + Value::Sequence(seq) => { + for (idx, value) in seq.iter().enumerate() { + let len = path.len(); + + path.push_str(&format!("[{}].", idx)); + generate_jsonpath_from_tagged_values(value, path, paths); + + path.truncate(len); + } + } + _ => {} + } + + Ok(()) +} diff --git a/generate/src/utils/mod.rs b/generate/src/utils/mod.rs new file mode 100644 index 0000000..118c66d --- /dev/null +++ b/generate/src/utils/mod.rs @@ -0,0 +1 @@ +pub mod generate; \ No newline at end of file