diff --git a/src/schema/mod.rs b/src/schema/mod.rs index 446190e..5a30181 100644 --- a/src/schema/mod.rs +++ b/src/schema/mod.rs @@ -17,10 +17,57 @@ use crate::signature::Signature; use chrono::prelude::*; use serde::Deserialize; -use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event}; +use quick_xml::events::{BytesDecl, BytesEnd, BytesStart, BytesText, Event}; use quick_xml::Writer; use std::io::Cursor; +use std::str::FromStr; + +use thiserror::Error; + +#[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)] +pub struct NameID { + #[serde(rename = "Format")] + pub format: Option, + + #[serde(rename = "$value")] + pub value: String, +} + +impl NameID { + fn name() -> &'static str { + "saml2:NameID" + } + + fn schema() -> &'static [(&'static str, &'static str)] { + &[("xmlns:saml2", "urn:oasis:names:tc:SAML:2.0:assertion")] + } +} + +impl TryFrom<&NameID> for Event<'_> { + type Error = Box; + + fn try_from(value: &NameID) -> Result { + let mut write_buf = Vec::new(); + let mut writer = Writer::new(Cursor::new(&mut write_buf)); + let mut root = BytesStart::new(NameID::name()); + + for attr in NameID::schema() { + root.push_attribute((attr.0, attr.1)); + } + + if let Some(format) = &value.format { + root.push_attribute(("Format", format.as_ref())); + } + + writer.write_event(Event::Start(root))?; + writer.write_event(Event::Text(BytesText::from_escaped(value.value.as_str())))?; + writer.write_event(Event::End(BytesEnd::new(NameID::name())))?; + Ok(Event::Text(BytesText::from_escaped(String::from_utf8( + write_buf, + )?))) + } +} #[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)] pub struct LogoutRequest { @@ -38,6 +85,81 @@ pub struct LogoutRequest { pub signature: Option, #[serde(rename = "@SessionIndex")] pub session_index: Option, + #[serde(rename = "NameID")] + pub name_id: Option, +} + +#[derive(Debug, Error)] +pub enum LogoutRequestError { + #[error("Failed to deserialize LogoutRequest: {:?}", source)] + ParseError { + #[from] + source: quick_xml::DeError, + }, +} + +impl FromStr for LogoutRequest { + type Err = LogoutRequestError; + + fn from_str(s: &str) -> Result { + Ok(quick_xml::de::from_str(s)?) + } +} + +const LOGOUT_REQUEST_NAME: &str = "saml2p:LogoutRequest"; +const SESSION_INDEX_NAME: &str = "saml2p:SessionIndex"; +const PROTOCOL_SCHEMA: (&str, &str) = ("xmlns:saml2p", "urn:oasis:names:tc:SAML:2.0:protocol"); + +impl LogoutRequest { + pub fn to_xml(&self) -> Result> { + let mut write_buf = Vec::new(); + let mut writer = Writer::new(Cursor::new(&mut write_buf)); + writer.write_event(Event::Decl(BytesDecl::new("1.0", Some("UTF-8"), None)))?; + + let mut root = BytesStart::new(LOGOUT_REQUEST_NAME); + root.push_attribute(PROTOCOL_SCHEMA); + if let Some(id) = &self.id { + root.push_attribute(("ID", id.as_ref())); + } + if let Some(version) = &self.version { + root.push_attribute(("Version", version.as_ref())); + } + if let Some(issue_instant) = &self.issue_instant { + root.push_attribute(( + "IssueInstant", + issue_instant + .to_rfc3339_opts(SecondsFormat::Millis, true) + .as_ref(), + )); + } + if let Some(destination) = &self.destination { + root.push_attribute(("Destination", destination.as_ref())); + } + + writer.write_event(Event::Start(root))?; + + if let Some(issuer) = &self.issuer { + let event: Event<'_> = issuer.try_into()?; + writer.write_event(event)?; + } + if let Some(signature) = &self.signature { + let event: Event<'_> = signature.try_into()?; + writer.write_event(event)?; + } + + if let Some(session) = &self.session_index { + writer.write_event(Event::Start(BytesStart::new(SESSION_INDEX_NAME)))?; + writer.write_event(Event::Text(BytesText::new(session)))?; + writer.write_event(Event::End(BytesEnd::new(SESSION_INDEX_NAME)))?; + } + if let Some(name_id) = &self.name_id { + let event: Event<'_> = name_id.try_into()?; + writer.write_event(event)?; + } + + writer.write_event(Event::End(BytesEnd::new(LOGOUT_REQUEST_NAME)))?; + Ok(String::from_utf8(write_buf)?) + } } #[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)] @@ -475,3 +597,120 @@ pub struct LogoutResponse { #[serde(rename = "Status")] pub status: Option, } + +#[derive(Debug, Error)] +pub enum LogoutResponseError { + #[error("Failed to deserialize LogoutResponse: {:?}", source)] + ParseError { + #[from] + source: quick_xml::DeError, + }, +} + +impl FromStr for LogoutResponse { + type Err = LogoutResponseError; + + fn from_str(s: &str) -> Result { + Ok(quick_xml::de::from_str(s)?) + } +} + +const LOGOUT_RESPONSE_NAME: &str = "saml2p:LogoutResponse"; + +impl LogoutResponse { + pub fn to_xml(&self) -> Result> { + let mut write_buf = Vec::new(); + let mut writer = Writer::new(Cursor::new(&mut write_buf)); + writer.write_event(Event::Decl(BytesDecl::new("1.0", Some("UTF-8"), None)))?; + + let mut root = BytesStart::new(LOGOUT_RESPONSE_NAME); + root.push_attribute(PROTOCOL_SCHEMA); + if let Some(id) = &self.id { + root.push_attribute(("ID", id.as_ref())); + } + if let Some(resp_to) = &self.in_response_to { + root.push_attribute(("InResponseTo", resp_to.as_ref())); + } + if let Some(version) = &self.version { + root.push_attribute(("Version", version.as_ref())); + } + if let Some(issue_instant) = &self.issue_instant { + root.push_attribute(( + "IssueInstant", + issue_instant + .to_rfc3339_opts(SecondsFormat::Millis, true) + .as_ref(), + )); + } + if let Some(destination) = &self.destination { + root.push_attribute(("Destination", destination.as_ref())); + } + if let Some(consent) = &self.consent { + root.push_attribute(("Consent", consent.as_ref())); + } + + writer.write_event(Event::Start(root))?; + + if let Some(issuer) = &self.issuer { + let event: Event<'_> = issuer.try_into()?; + writer.write_event(event)?; + } + if let Some(signature) = &self.signature { + let event: Event<'_> = signature.try_into()?; + writer.write_event(event)?; + } + + if let Some(status) = &self.status { + let event: Event<'_> = status.try_into()?; + writer.write_event(event)?; + } + + writer.write_event(Event::End(BytesEnd::new(LOGOUT_RESPONSE_NAME)))?; + Ok(String::from_utf8(write_buf)?) + } +} + +#[cfg(test)] +mod test { + use super::issuer::Issuer; + use super::{LogoutRequest, LogoutResponse, NameID, Status, StatusCode}; + use chrono::TimeZone; + + #[test] + fn test_deserialize_serialize_logout_request() { + let request_xml = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/test_vectors/logout_request.xml", + )); + let expected_request: LogoutRequest = request_xml + .parse() + .expect("failed to parse logout_request.xml"); + let serialized_request = expected_request + .to_xml() + .expect("failed to convert request to xml"); + let actual_request: LogoutRequest = serialized_request + .parse() + .expect("failed to re-parse request"); + + assert_eq!(expected_request, actual_request); + } + + #[test] + fn test_deserialize_serialize_logout_response() { + let response_xml = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/test_vectors/logout_response.xml", + )); + let expected_response: LogoutResponse = response_xml + .parse() + .expect("failed to parse logout_response.xml"); + let serialized_response = expected_response + .to_xml() + .expect("failed to convert Response to xml"); + let actual_response: LogoutResponse = serialized_response + .parse() + .expect("failed to re-parse Response"); + + assert_eq!(expected_response, actual_response); + } +} diff --git a/test_vectors/logout_request.xml b/test_vectors/logout_request.xml new file mode 100644 index 0000000..4c28380 --- /dev/null +++ b/test_vectors/logout_request.xml @@ -0,0 +1,6 @@ + + + http://sp.example.com/demo1/metadata.php + session-index-1 + test@example.com + \ No newline at end of file diff --git a/test_vectors/logout_response.xml b/test_vectors/logout_response.xml new file mode 100644 index 0000000..7405a51 --- /dev/null +++ b/test_vectors/logout_response.xml @@ -0,0 +1,7 @@ + + + http://sp.example.com/demo1/metadata.php + + + + \ No newline at end of file