Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactors #59

Merged
merged 1 commit into from
Jun 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 8 additions & 26 deletions src/crypto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::ffi::CString;
use std::str::FromStr;
use thiserror::Error;

use crate::signature::SignatureAlgorithm;
#[cfg(feature = "xmlsec")]
use crate::xmlsec::{self, XmlSecKey, XmlSecKeyFormat, XmlSecSignatureContext};
#[cfg(feature = "xmlsec")]
Expand Down Expand Up @@ -223,6 +224,7 @@ fn get_elements_by_predicate<F: FnMut(&libxml::tree::Node) -> bool>(
/// Searches for and returns the element with the given value of the `ID` attribute from the subtree
/// rooted at the given node.
#[cfg(feature = "xmlsec")]
#[allow(unused)]
fn get_element_by_id(elem: &libxml::tree::Node, id: &str) -> Option<libxml::tree::Node> {
let mut elems = get_elements_by_predicate(elem, |node| {
node.get_attribute("ID")
Expand Down Expand Up @@ -486,24 +488,6 @@ pub fn gen_saml_assertion_id() -> String {
format!("_{}", uuid::Uuid::new_v4())
}

#[derive(Debug, PartialEq)]
enum SigAlg {
Unimplemented,
RsaSha256,
EcdsaSha256,
}

impl FromStr for SigAlg {
type Err = Box<dyn std::error::Error>;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"http://www.w3.org/2001/04/xmldsig-more#rsa-sha256" => Ok(SigAlg::RsaSha256),
"http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256" => Ok(SigAlg::EcdsaSha256),
_ => Ok(SigAlg::Unimplemented),
}
}
}

#[derive(Debug, Error, Clone)]
pub enum UrlVerifierError {
#[error("Unimplemented SigAlg: {:?}", sigalg)]
Expand Down Expand Up @@ -621,11 +605,9 @@ impl UrlVerifier {
.collect::<HashMap<String, String>>();

// Match against implemented SigAlg
let sig_alg: SigAlg = SigAlg::from_str(&query_params["SigAlg"])?;
if sig_alg == SigAlg::Unimplemented {
return Err(Box::new(UrlVerifierError::SigAlgUnimplemented {
sigalg: query_params["SigAlg"].clone(),
}));
let sig_alg = SignatureAlgorithm::from_str(&query_params["SigAlg"])?;
if let SignatureAlgorithm::Unsupported(sigalg) = sig_alg {
return Err(Box::new(UrlVerifierError::SigAlgUnimplemented { sigalg }));
}

// Construct a Url so that percent encoded query can be easily
Expand Down Expand Up @@ -668,13 +650,13 @@ impl UrlVerifier {
fn verify_signature(
&self,
data: &[u8],
sig_alg: SigAlg,
sig_alg: SignatureAlgorithm,
signature: &[u8],
) -> Result<bool, Box<dyn std::error::Error>> {
let mut verifier = openssl::sign::Verifier::new(
match sig_alg {
SigAlg::RsaSha256 => openssl::hash::MessageDigest::sha256(),
SigAlg::EcdsaSha256 => openssl::hash::MessageDigest::sha256(),
SignatureAlgorithm::RsaSha256 => openssl::hash::MessageDigest::sha256(),
SignatureAlgorithm::EcdsaSha256 => openssl::hash::MessageDigest::sha256(),
_ => panic!("sig_alg is bad!"),
},
&self.public_key,
Expand Down
52 changes: 40 additions & 12 deletions src/idp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ pub mod verified_request;
mod tests;

use openssl::bn::{BigNum, MsbOption};
use openssl::ec::{EcGroup, EcKey};
use openssl::nid::Nid;
use openssl::pkey::Private;
use openssl::{asn1::Asn1Time, pkey, rsa::Rsa, x509};
use openssl::{asn1::Asn1Time, pkey, x509};
use std::str::FromStr;

use crate::crypto::{self};
Expand All @@ -24,22 +25,31 @@ pub struct IdentityProvider {
private_key: pkey::PKey<Private>,
}

pub enum KeyType {
pub enum Rsa {
Rsa2048,
Rsa3072,
Rsa4096,
}

impl KeyType {
impl Rsa {
fn bit_length(&self) -> u32 {
match &self {
KeyType::Rsa2048 => 2048,
KeyType::Rsa3072 => 3072,
KeyType::Rsa4096 => 4096,
Rsa::Rsa2048 => 2048,
Rsa::Rsa3072 => 3072,
Rsa::Rsa4096 => 4096,
}
}
}

pub enum Elliptic {
NISTP256,
}

pub enum KeyType {
Rsa(Rsa),
Elliptic(Elliptic),
}

pub struct CertificateParams<'a> {
pub common_name: &'a str,
pub issuer_name: &'a str,
Expand All @@ -48,22 +58,40 @@ pub struct CertificateParams<'a> {

impl IdentityProvider {
pub fn generate_new(key_type: KeyType) -> Result<Self, Error> {
let rsa = Rsa::generate(key_type.bit_length())?;
let private_key = pkey::PKey::from_rsa(rsa)?;
let private_key = match key_type {
KeyType::Rsa(rsa) => {
let bit_length = rsa.bit_length();
let rsa = openssl::rsa::Rsa::generate(bit_length)?;
pkey::PKey::from_rsa(rsa)?
}
KeyType::Elliptic(ecc) => {
let nid = match ecc {
Elliptic::NISTP256 => Nid::X9_62_PRIME256V1,
};
let group = EcGroup::from_curve_name(nid)?;
let private_key: EcKey<Private> = EcKey::generate(&group)?;
pkey::PKey::from_ec_key(private_key)?
}
};

Ok(IdentityProvider { private_key })
}

pub fn from_private_key_der(der_bytes: &[u8]) -> Result<Self, Error> {
let rsa = Rsa::private_key_from_der(der_bytes)?;
pub fn from_rsa_private_key_der(der_bytes: &[u8]) -> Result<Self, Error> {
let rsa = openssl::rsa::Rsa::private_key_from_der(der_bytes)?;
let private_key = pkey::PKey::from_rsa(rsa)?;

Ok(IdentityProvider { private_key })
}

pub fn export_private_key_der(&self) -> Result<Vec<u8>, Error> {
let rsa: Rsa<Private> = self.private_key.rsa()?;
Ok(rsa.private_key_to_der()?)
if let Ok(ec_key) = self.private_key.ec_key() {
Ok(ec_key.private_key_to_der()?)
} else if let Ok(rsa) = self.private_key.rsa() {
Ok(rsa.private_key_to_der()?)
} else {
Err(Error::UnexpectedError)?
}
}

pub fn create_certificate(&self, params: &CertificateParams) -> Result<Vec<u8>, Error> {
Expand Down
4 changes: 2 additions & 2 deletions src/idp/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ fn test_extract_sp() {
#[test]
fn test_signed_response() {
// init our IdP
let idp = IdentityProvider::from_private_key_der(include_bytes!(
let idp = IdentityProvider::from_rsa_private_key_der(include_bytes!(
"../../test_vectors/idp_private_key.der"
))
.expect("failed to create idp");
Expand Down Expand Up @@ -135,7 +135,7 @@ fn test_signed_response_threads() {

#[test]
fn test_signed_response_fingerprint() {
let idp = IdentityProvider::from_private_key_der(include_bytes!(
let idp = IdentityProvider::from_rsa_private_key_der(include_bytes!(
"../../test_vectors/idp_private_key.der"
))
.expect("failed to create idp");
Expand Down
63 changes: 45 additions & 18 deletions src/metadata/entity_descriptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use chrono::prelude::*;
use quick_xml::events::{BytesDecl, BytesEnd, BytesStart, BytesText, Event};
use quick_xml::Writer;
use serde::Deserialize;
use std::collections::VecDeque;
use std::io::Cursor;
use std::str::FromStr;
use thiserror::Error;
Expand All @@ -29,18 +30,8 @@ pub enum EntityDescriptorType {
}

impl EntityDescriptorType {
pub fn take_first(self) -> Option<EntityDescriptor> {
match self {
EntityDescriptorType::EntitiesDescriptor(descriptor) => descriptor
.descriptors
.into_iter()
.next()
.and_then(|descriptor_type| match descriptor_type {
EntityDescriptorType::EntitiesDescriptor(_) => None,
EntityDescriptorType::EntityDescriptor(descriptor) => Some(descriptor),
}),
EntityDescriptorType::EntityDescriptor(descriptor) => Some(descriptor),
}
pub fn iter(&self) -> EntityDescriptorIterator {
EntityDescriptorIterator::new(self)
}
}

Expand Down Expand Up @@ -284,6 +275,39 @@ impl TryFrom<&EntityDescriptor> for Event<'_> {
}
}

#[derive(Clone)]
pub struct EntityDescriptorIterator<'a> {
queue: VecDeque<&'a EntityDescriptorType>,
}

impl<'a> EntityDescriptorIterator<'a> {
pub fn new(root: &'a EntityDescriptorType) -> Self {
let mut queue = VecDeque::new();
queue.push_back(root);
EntityDescriptorIterator { queue }
}
}

impl<'a> Iterator for EntityDescriptorIterator<'a> {
type Item = &'a EntityDescriptor;

fn next(&mut self) -> Option<Self::Item> {
while let Some(current) = self.queue.pop_front() {
match current {
EntityDescriptorType::EntitiesDescriptor(entities_descriptor) => {
for descriptor in &entities_descriptor.descriptors {
self.queue.push_back(descriptor);
}
}
EntityDescriptorType::EntityDescriptor(entity_descriptor) => {
return Some(entity_descriptor);
}
}
}
None
}
}

#[cfg(test)]
mod test {
use crate::traits::ToXml;
Expand Down Expand Up @@ -345,6 +369,7 @@ mod test {
.parse()
.expect("Failed to parse EntitiesDescriptor");

assert_eq!(2, reparsed_entities_descriptor.descriptors.len());
assert_eq!(reparsed_entities_descriptor, entities_descriptor);
}

Expand All @@ -369,11 +394,12 @@ mod test {
let expected_entity_descriptor: EntityDescriptor = input_xml
.parse()
.expect("Failed to parse idp_metadata.xml into an EntityDescriptor");
let entity_descriptor: EntityDescriptor = entity_descriptor_type
.take_first()
let entity_descriptor = entity_descriptor_type
.iter()
.next()
.expect("Failed to take first EntityDescriptor from EntityDescriptorType");

assert_eq!(expected_entity_descriptor, entity_descriptor);
assert_eq!(&expected_entity_descriptor, entity_descriptor);
}

#[test]
Expand Down Expand Up @@ -401,11 +427,12 @@ mod test {
let expected_entity_descriptor: EntityDescriptor = input_xml
.parse()
.expect("Failed to parse idp_metadata.xml into an EntityDescriptor");
let entity_descriptor: EntityDescriptor = entity_descriptor_type
.take_first()
let entity_descriptor = entity_descriptor_type
.iter()
.next()
.expect("Failed to take first EntityDescriptor from EntityDescriptorType");
println!("{entity_descriptor:#?}");

assert_eq!(expected_entity_descriptor, entity_descriptor);
assert_eq!(&expected_entity_descriptor, entity_descriptor);
}
}
Loading
Loading