Skip to content

Commit

Permalink
Switch to jwtk (#1)
Browse files Browse the repository at this point in the history
* switch to jwtk

* update changelog
  • Loading branch information
tizz98 authored Aug 12, 2023
1 parent 6e95d34 commit ba5494a
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 106 deletions.
9 changes: 7 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]

### Changed
- **Breaking change**: `decode` now borrows the XQR instead of taking ownership of it
- This allows the XQR to be reused after decoding
- **Breaking changes**:
- `decode` now borrows the XQR instead of taking ownership of it
- This allows the XQR to be reused after decoding
- `encode` takes an issuer rather than key id
- `fetch_public_key` takes an issuer in addition to key id
- This allows the issuer to be used to fetch the public key and look up by key id
- Change from `jwt-simple` to `jwtk`

## [0.3.0] - 2023-08-11
### Added
Expand Down
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,7 @@ keywords = ["secure", "qr", "code"]
serde = {version = "1.0", features = ["derive"] }
serde_json = "1.0"
serde_derive = "1.0"
jwt-simple = "0.11.6"
reqwest = {version = "0.11.18", features = ["blocking", "json"]}
jwtk = "0.2.4"
url = "2.4.0"
base64 = "0.21.2"
218 changes: 115 additions & 103 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
use jwt_simple::prelude::*;
use jwt_simple::Error;
use base64::engine::general_purpose::NO_PAD;
use base64::engine::{GeneralPurpose, GeneralPurposeConfig};
use base64::{alphabet, Engine as _};
use jwtk::jwk::WithKid;
use jwtk::{decode_without_verify, ecdsa, sign, verify, HeaderAndClaims};
use reqwest;
use serde_derive::{Deserialize, Serialize};
use serde_json::Value;
use std::error::Error;

const NO_PAD_TRAILING_BITS: GeneralPurposeConfig = NO_PAD.with_decode_allow_trailing_bits(true);
const URL_SAFE_NO_PAD: GeneralPurpose =
GeneralPurpose::new(&alphabet::URL_SAFE, NO_PAD_TRAILING_BITS);

/// Represents the extended quick response (XQR) code, encapsulating the JWT token.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
Expand All @@ -17,8 +25,20 @@ impl XQR {
///
/// An Option containing the Key ID as a string if present, or None if not found.
pub fn get_kid(&self) -> Option<String> {
match Token::decode_metadata(&self.token) {
Ok(metadata) => metadata.key_id().map(|s| s.to_string()),
match decode_without_verify::<XQRClaims>(&self.token) {
Ok(header) => header.header().kid.clone().map(|s| s.to_string()),
Err(_) => None,
}
}

/// Returns the value from the JWT token contained in the XQR structure.
///
/// # Returns
///
/// An Option containing the value as a string if present, or None if not found.
pub fn get_iss(&self) -> Option<String> {
match decode_without_verify::<XQRClaims>(&self.token) {
Ok(header) => header.claims().iss.clone().map(|s| s.to_string()),
Err(_) => None,
}
}
Expand Down Expand Up @@ -59,28 +79,23 @@ pub struct XQRClaims {
pub fn encode(
private_key_pem: &str,
value: &str,
kid: &str,
valid_for: Option<Duration>,
) -> Result<XQR, Error> {
let key_pair = ES256KeyPair::from_pem(private_key_pem)?;
let key_pair = key_pair.with_key_id(kid);
let initial_duration = match valid_for {
Some(duration) => duration,
// with_custom_claims requires a non-None duration, so we use 0 if valid_for is None.
// After creating the claims, we'll set the expires_at value to None.
None => Duration::from_hours(0),
};
let mut claims = Claims::with_custom_claims(
XQRClaims {
value: value.to_string(),
},
initial_duration,
);
if valid_for.is_none() {
claims.expires_at = None;
iss: &str,
valid_for: Option<std::time::Duration>,
) -> jwtk::Result<XQR> {
let private_key = ecdsa::EcdsaPrivateKey::from_pem(private_key_pem.as_ref())?;
let private_key = WithKid::new_with_thumbprint_id(private_key)?;

let mut claims = HeaderAndClaims::new_dynamic();
let claims = claims
.insert("value", value)
.set_iss(iss)
.set_iat_now()
.set_nbf_from_now(std::time::Duration::from_secs(0));
if valid_for.is_some() {
claims.set_exp_from_now(valid_for.unwrap());
}
let token = key_pair.sign(claims)?;

let token = sign(claims, &private_key)?;
Ok(XQR { token })
}

Expand All @@ -94,28 +109,26 @@ pub fn encode(
/// # Returns
///
/// A Result containing the decoded value as a String or an error if the operation fails.
pub fn decode(public_key_pem: &str, xqr: &XQR) -> Result<String, Error> {
let public_key = ES256PublicKey::from_pem(public_key_pem)?;
let claims = public_key
.verify_token::<XQRClaims>(&xqr.token, None)?
.custom;

Ok(claims.value)
pub fn decode(public_key_pem: &str, xqr: &XQR) -> jwtk::Result<String> {
let public_key = ecdsa::EcdsaPublicKey::from_pem(public_key_pem.as_ref())?;
let verified = verify::<XQRClaims>(&xqr.token, &public_key)?;
Ok(verified.claims().extra.value.clone())
}

/// Fetches the public key based on the key ID.
///
/// # Arguments
///
/// * `issuer` - The issuer URL (e.g. https://example.com, https://demo.xqr.dev).
/// * `key_id` - The key ID in the format "example.com#123".
///
/// # Returns
///
/// A Result containing the public key as a string in PEM format or an error if the operation fails.
pub fn fetch_public_key(key_id: &str) -> Result<String, Box<dyn std::error::Error>> {
// Extract the URL from the key_id
let url_parts: Vec<&str> = key_id.split('#').collect();
let url = format!("https://{}/.well-known/jwks.json", url_parts[0]);
pub fn fetch_public_key(issuer: &str, key_id: &str) -> Result<String, Box<dyn Error>> {
let domain = url::Url::parse(issuer)?;
let domain = domain.host_str().unwrap();
let url = format!("https://{}/.well-known/jwks.json", domain);

// Make the HTTP request
let response = reqwest::blocking::get(&url)?;
Expand All @@ -125,12 +138,12 @@ pub fn fetch_public_key(key_id: &str) -> Result<String, Box<dyn std::error::Erro
if let Some(keys) = jwks["keys"].as_array() {
for key in keys {
if key["kid"].as_str() == Some(key_id) {
// Extract and return the public key in your desired format (e.g., PEM)
// The actual extraction may vary depending on the JWKS structure
return Ok(key["x5c"].as_array().unwrap()[0]
.as_str()
.unwrap()
.to_string());
let pub_key = ecdsa::EcdsaPublicKey::from_coordinates(
&URL_SAFE_NO_PAD.decode(key["x"].as_str().unwrap())?,
&URL_SAFE_NO_PAD.decode(key["y"].as_str().unwrap())?,
ecdsa::EcdsaAlgorithm::ES256,
)?;
return Ok(pub_key.to_pem()?);
}
}
}
Expand All @@ -141,65 +154,69 @@ pub fn fetch_public_key(key_id: &str) -> Result<String, Box<dyn std::error::Erro
)))
}

/// Generates a new ECDSA (ES256) key pair for use with JWT tokens.
/// Generates a new ECDSA (ES256) private key for use with JWT tokens.
///
/// # Returns
///
/// The generated ES256 key pair.
pub fn generate_key_pair() -> ES256KeyPair {
ES256KeyPair::generate()
/// The generated ES256 private key.
pub fn generate_key() -> jwtk::Result<ecdsa::EcdsaPrivateKey> {
ecdsa::EcdsaPrivateKey::generate(ecdsa::EcdsaAlgorithm::ES256)
}

#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;

#[test]
fn encode_decode_test() {
let key_pair = generate_key_pair();
let private_key = key_pair.to_pem().unwrap();
let public_key = key_pair.public_key().to_pem().unwrap();
let value = "value";
let key = generate_key().unwrap();
let private_key = key.private_key_to_pem_pkcs8().unwrap();
let public_key = key.public_key_to_pem().unwrap();

let encoded_xqr = encode(&private_key, value, "example.com#123", None).unwrap();
let encoded_xqr = encode(&private_key, "value", "https://example.com", None).unwrap();
let decoded_value = decode(&public_key, &encoded_xqr).unwrap();

assert_eq!(decoded_value, value);
assert_eq!(decoded_value, "value");
}

#[test]
fn decode_with_wrong_pub_key_fails() {
let key_pair = generate_key_pair();
let private_key = key_pair.to_pem().unwrap();
let public_key = generate_key_pair().public_key().to_pem().unwrap();
let value = "value";
let kid = "example.com#123";
let key = generate_key().unwrap();
let private_key = key.private_key_to_pem_pkcs8().unwrap();
let public_key = generate_key().unwrap().public_key_to_pem().unwrap();

let encoded_xqr = encode(&private_key, value, kid, None).unwrap();
let encoded_xqr = encode(&private_key, "value", "https://example.com", None).unwrap();
let decoded_value = decode(&public_key, &encoded_xqr);

assert!(decoded_value.is_err());
}

#[test]
fn get_kid_test() {
let key_pair = generate_key_pair();
let private_key = key_pair.to_pem().unwrap();
let value = "value";
let kid = "example.com#123";
let key = generate_key().unwrap();
let private_key = key.private_key_to_pem_pkcs8().unwrap();

let encoded_xqr = encode(&private_key, value, kid, None).unwrap();
let encoded_xqr = encode(&private_key, "value", "https://example.com", None).unwrap();

assert_eq!(encoded_xqr.get_kid().unwrap(), kid);
assert!(encoded_xqr.get_kid().is_some());
}

#[test]
fn pem_serialization_test() {
let key_pair = generate_key_pair();
fn get_iss_test() {
let key = generate_key().unwrap();
let private_key = key.private_key_to_pem_pkcs8().unwrap();

let encoded_xqr = encode(&private_key, "value", "https://example.com", None).unwrap();

assert_eq!(encoded_xqr.get_iss().unwrap(), "https://example.com");
}

// Convert keys to PEM
let private_pem = key_pair.to_pem().unwrap();
let public_pem = key_pair.public_key().to_pem().unwrap();
#[test]
fn pem_serialization_test() {
let key = generate_key().unwrap();
let private_pem = key.private_key_to_pem_pkcs8().unwrap();
let public_pem = key.public_key_to_pem().unwrap();

// Verify that the PEM strings contain the correct headers
assert!(private_pem.contains("-----BEGIN PRIVATE KEY-----"));
Expand All @@ -210,59 +227,54 @@ mod tests {

#[test]
fn xqr_to_string_ergonomics() {
let key_pair = generate_key_pair();
let private_key = key_pair.to_pem().unwrap();
let value = "value";
let kid = "example.com#123";
let key = generate_key().unwrap();
let private_key = key.private_key_to_pem_pkcs8().unwrap();

let encoded_xqr = encode(&private_key, value, kid, None).unwrap();
let encoded_xqr = encode(&private_key, "value", "https://example.com", None).unwrap();

assert_eq!(encoded_xqr.to_string(), encoded_xqr.token);
}

#[test]
fn xqr_from_string_ergonomics() {
let key_pair = generate_key_pair();
let private_key = key_pair.to_pem().unwrap();
let value = "value";
let kid = "example.com#123";
let key = generate_key().unwrap();
let private_key = key.private_key_to_pem_pkcs8().unwrap();

let encoded_xqr = encode(&private_key, value, kid, None).unwrap();
let encoded_xqr = encode(&private_key, "value", "https://example.com", None).unwrap();
let encoded_xqr_string = encoded_xqr.to_string();

assert_eq!(XQR::from(encoded_xqr_string), encoded_xqr);
}

#[test]
fn expiration_is_not_set_when_valid_for_is_none() {
let key_pair = generate_key_pair();
let private_key = key_pair.to_pem().unwrap();
let pub_key = key_pair.public_key();
let value = "value";
let kid = "example.com#123";

let encoded_xqr = encode(&private_key, value, kid, None).unwrap();
let claims = pub_key
.verify_token::<XQRClaims>(&encoded_xqr.token, None)
.unwrap();

assert!(claims.expires_at.is_none());
let key = generate_key().unwrap();
let private_key = key.private_key_to_pem_pkcs8().unwrap();
let public_key = key.public_key_to_pem().unwrap();
let public_key = ecdsa::EcdsaPublicKey::from_pem(public_key.as_ref()).unwrap();

let encoded_xqr = encode(&private_key, "value", "https://example.com", None).unwrap();
let claims = verify::<XQRClaims>(&encoded_xqr.token, &public_key).unwrap();

assert!(claims.claims().exp.is_none());
}

#[test]
fn expiration_is_set_when_valid_for_is_not_none() {
let key_pair = generate_key_pair();
let private_key = key_pair.to_pem().unwrap();
let pub_key = key_pair.public_key();
let value = "value";
let kid = "example.com#123";
let valid_for = Duration::from_secs(60);

let encoded_xqr = encode(&private_key, value, kid, Some(valid_for)).unwrap();
let claims = pub_key
.verify_token::<XQRClaims>(&encoded_xqr.token, None)
.unwrap();

assert!(claims.expires_at.is_some());
let key = generate_key().unwrap();
let private_key = key.private_key_to_pem_pkcs8().unwrap();
let public_key = key.public_key_to_pem().unwrap();
let public_key = ecdsa::EcdsaPublicKey::from_pem(public_key.as_ref()).unwrap();

let encoded_xqr = encode(
&private_key,
"value",
"https://example.com",
Some(Duration::from_secs(60)),
)
.unwrap();
let claims = verify::<XQRClaims>(&encoded_xqr.token, &public_key).unwrap();

assert!(claims.claims().exp.is_some());
}
}

0 comments on commit ba5494a

Please sign in to comment.