diff --git a/src/s3/error.rs b/src/s3/error.rs index bebbba0..52d1490 100644 --- a/src/s3/error.rs +++ b/src/s3/error.rs @@ -52,8 +52,8 @@ impl ErrorResponse { } } -#[derive(Debug)] /// Error definitions +#[derive(Debug)] pub enum Error { TimeParseError(chrono::ParseError), InvalidUrl(http::uri::InvalidUri), diff --git a/src/s3/response/list_objects.rs b/src/s3/response/list_objects.rs index 24b2281..2b41572 100644 --- a/src/s3/response/list_objects.rs +++ b/src/s3/response/list_objects.rs @@ -17,12 +17,14 @@ use std::collections::HashMap; use async_trait::async_trait; use bytes::Buf; use reqwest::header::HeaderMap; -use xmltree::Element; use crate::s3::{ error::Error, types::{FromS3Response, ListEntry, S3Request}, - utils::{from_iso8601utc, get_default_text, get_option_text, get_text, urldecode}, + utils::{ + from_iso8601utc, urldecode, + xml::{Element, MergeXmlElements}, + }, }; fn url_decode( @@ -57,65 +59,73 @@ fn parse_common_list_objects_response( ), Error, > { - let encoding_type = get_option_text(root, "EncodingType"); - let prefix = url_decode(&encoding_type, Some(get_default_text(root, "Prefix")))?; + let encoding_type = root.get_child_text("EncodingType"); + let prefix = url_decode( + &encoding_type, + Some(root.get_child_text("Prefix").unwrap_or_default()), + )?; Ok(( - get_text(root, "Name")?, + root.get_child_text_or_error("Name")?, encoding_type, prefix, - get_option_text(root, "Delimiter"), - match get_option_text(root, "IsTruncated") { - Some(v) => v.to_lowercase() == "true", - None => false, - }, - match get_option_text(root, "MaxKeys") { - Some(v) => Some(v.parse::()?), - None => None, - }, + root.get_child_text("Delimiter"), + root.get_child_text("IsTruncated") + .map(|x| x.to_lowercase() == "true") + .unwrap_or(false), + root.get_child_text("MaxKeys") + .map(|x| x.parse::()) + .transpose()?, )) } fn parse_list_objects_contents( contents: &mut Vec, - root: &mut xmltree::Element, - tag: &str, + root: &Element, + main_tag: &str, encoding_type: &Option, - is_delete_marker: bool, + with_delete_marker: bool, ) -> Result<(), Error> { - while let Some(v) = root.take_child(tag) { - let content = v; + let children1 = root.get_matching_children(main_tag); + let children2 = if with_delete_marker { + root.get_matching_children("DeleteMarker") + } else { + vec![] + }; + let mut merged = MergeXmlElements::new(&children1, &children2); + while let Some(content) = merged.next() { let etype = encoding_type.as_ref().cloned(); - let key = url_decode(&etype, Some(get_text(&content, "Key")?))?.unwrap(); - let last_modified = Some(from_iso8601utc(&get_text(&content, "LastModified")?)?); - let etag = get_option_text(&content, "ETag"); - let v = get_default_text(&content, "Size"); - let size = match v.is_empty() { - true => None, - false => Some(v.parse::()?), - }; - let storage_class = get_option_text(&content, "StorageClass"); - let is_latest = get_default_text(&content, "IsLatest").to_lowercase() == "true"; - let version_id = get_option_text(&content, "VersionId"); - let (owner_id, owner_name) = match content.get_child("Owner") { - Some(v) => (get_option_text(v, "ID"), get_option_text(v, "DisplayName")), - None => (None, None), - }; - let user_metadata = match content.get_child("UserMetadata") { - Some(v) => { - let mut map: HashMap = HashMap::new(); - for xml_node in &v.children { - let u = xml_node - .as_element() - .ok_or(Error::XmlError("unable to convert to element".to_string()))?; - map.insert( - u.name.to_string(), - u.get_text().unwrap_or_default().to_string(), - ); - } - Some(map) - } - None => None, - }; + let key = url_decode(&etype, Some(content.get_child_text_or_error("Key")?))?.unwrap(); + let last_modified = Some(from_iso8601utc( + &content.get_child_text_or_error("LastModified")?, + )?); + let etag = content.get_child_text("ETag"); + let size: Option = content + .get_child_text("Size") + .map(|x| x.parse::()) + .transpose()?; + let storage_class = content.get_child_text("StorageClass"); + let is_latest = content + .get_child_text("IsLatest") + .unwrap_or_default() + .to_lowercase() + == "true"; + let version_id = content.get_child_text("VersionId"); + let (owner_id, owner_name) = content + .get_child("Owner") + .map(|v| (v.get_child_text("ID"), v.get_child_text("DisplayName"))) + .unwrap_or((None, None)); + let user_metadata = content.get_child("UserMetadata").map(|v| { + v.get_xmltree_children() + .into_iter() + .map(|elem| { + ( + elem.name.to_string(), + elem.get_text().unwrap_or_default().to_string(), + ) + }) + .collect::>() + }); + let is_delete_marker = content.name() == "DeleteMarker"; contents.push(ListEntry { name: key, @@ -139,13 +149,16 @@ fn parse_list_objects_contents( fn parse_list_objects_common_prefixes( contents: &mut Vec, - root: &mut Element, + root: &Element, encoding_type: &Option, ) -> Result<(), Error> { - while let Some(v) = root.take_child("CommonPrefixes") { - let common_prefix = v; + for (_, common_prefix) in root.get_matching_children("CommonPrefixes") { contents.push(ListEntry { - name: url_decode(encoding_type, Some(get_text(&common_prefix, "Prefix")?))?.unwrap(), + name: url_decode( + encoding_type, + Some(common_prefix.get_child_text_or_error("Prefix")?), + )? + .unwrap(), last_modified: None, etag: None, owner_id: None, @@ -187,18 +200,19 @@ impl FromS3Response for ListObjectsV1Response { ) -> Result { let headers = resp.headers().clone(); let body = resp.bytes().await?; - let mut root = Element::parse(body.reader())?; + let xmltree_root = xmltree::Element::parse(body.reader())?; + let root = Element::from(&xmltree_root); let (name, encoding_type, prefix, delimiter, is_truncated, max_keys) = parse_common_list_objects_response(&root)?; - let marker = url_decode(&encoding_type, get_option_text(&root, "Marker"))?; - let mut next_marker = url_decode(&encoding_type, get_option_text(&root, "NextMarker"))?; + let marker = url_decode(&encoding_type, root.get_child_text("Marker"))?; + let mut next_marker = url_decode(&encoding_type, root.get_child_text("NextMarker"))?; let mut contents: Vec = Vec::new(); - parse_list_objects_contents(&mut contents, &mut root, "Contents", &encoding_type, false)?; + parse_list_objects_contents(&mut contents, &root, "Contents", &encoding_type, false)?; if is_truncated && next_marker.is_none() { next_marker = contents.last().map(|v| v.name.clone()) } - parse_list_objects_common_prefixes(&mut contents, &mut root, &encoding_type)?; + parse_list_objects_common_prefixes(&mut contents, &root, &encoding_type)?; Ok(ListObjectsV1Response { headers, @@ -240,24 +254,21 @@ impl FromS3Response for ListObjectsV2Response { ) -> Result { let headers = resp.headers().clone(); let body = resp.bytes().await?; - let mut root = Element::parse(body.reader())?; + let xmltree_root = xmltree::Element::parse(body.reader())?; + let root = Element::from(&xmltree_root); let (name, encoding_type, prefix, delimiter, is_truncated, max_keys) = parse_common_list_objects_response(&root)?; - let text = get_option_text(&root, "KeyCount"); - let key_count = match text { - Some(v) => match v.is_empty() { - true => None, - false => Some(v.parse::()?), - }, - None => None, - }; - let start_after = url_decode(&encoding_type, get_option_text(&root, "StartAfter"))?; - let continuation_token = get_option_text(&root, "ContinuationToken"); - let next_continuation_token = get_option_text(&root, "NextContinuationToken"); + let key_count = root + .get_child_text("KeyCount") + .map(|x| x.parse::()) + .transpose()?; + let start_after = url_decode(&encoding_type, root.get_child_text("StartAfter"))?; + let continuation_token = root.get_child_text("ContinuationToken"); + let next_continuation_token = root.get_child_text("NextContinuationToken"); let mut contents: Vec = Vec::new(); - parse_list_objects_contents(&mut contents, &mut root, "Contents", &encoding_type, false)?; - parse_list_objects_common_prefixes(&mut contents, &mut root, &encoding_type)?; + parse_list_objects_contents(&mut contents, &root, "Contents", &encoding_type, false)?; + parse_list_objects_common_prefixes(&mut contents, &root, &encoding_type)?; Ok(ListObjectsV2Response { headers, @@ -301,24 +312,18 @@ impl FromS3Response for ListObjectVersionsResponse { ) -> Result { let headers = resp.headers().clone(); let body = resp.bytes().await?; - let mut root = Element::parse(body.reader())?; + let xmltree_root = xmltree::Element::parse(body.reader())?; + let root = Element::from(&xmltree_root); let (name, encoding_type, prefix, delimiter, is_truncated, max_keys) = parse_common_list_objects_response(&root)?; - let key_marker = url_decode(&encoding_type, get_option_text(&root, "KeyMarker"))?; - let next_key_marker = url_decode(&encoding_type, get_option_text(&root, "NextKeyMarker"))?; - let version_id_marker = get_option_text(&root, "VersionIdMarker"); - let next_version_id_marker = get_option_text(&root, "NextVersionIdMarker"); + let key_marker = url_decode(&encoding_type, root.get_child_text("KeyMarker"))?; + let next_key_marker = url_decode(&encoding_type, root.get_child_text("NextKeyMarker"))?; + let version_id_marker = root.get_child_text("VersionIdMarker"); + let next_version_id_marker = root.get_child_text("NextVersionIdMarker"); let mut contents: Vec = Vec::new(); - parse_list_objects_contents(&mut contents, &mut root, "Version", &encoding_type, false)?; - parse_list_objects_common_prefixes(&mut contents, &mut root, &encoding_type)?; - parse_list_objects_contents( - &mut contents, - &mut root, - "DeleteMarker", - &encoding_type, - true, - )?; + parse_list_objects_contents(&mut contents, &root, "Version", &encoding_type, true)?; + parse_list_objects_common_prefixes(&mut contents, &root, &encoding_type)?; Ok(ListObjectVersionsResponse { headers, diff --git a/src/s3/utils.rs b/src/s3/utils.rs index 088b28d..bcaf73c 100644 --- a/src/s3/utils.rs +++ b/src/s3/utils.rs @@ -391,3 +391,157 @@ pub fn copy_slice(dst: &mut [u8], src: &[u8]) -> usize { } c } + +pub mod xml { + use std::collections::HashMap; + + use crate::s3::error::Error; + + struct XmlElementIndex { + children: HashMap>, + } + + impl XmlElementIndex { + fn get_first(&self, tag: &str) -> Option { + let tag: String = tag.to_string(); + let is = self.children.get(&tag)?; + is.first().map(|v| *v) + } + + fn get(&self, tag: &str) -> Option<&Vec> { + let tag: String = tag.to_string(); + self.children.get(&tag) + } + } + + impl From<&xmltree::Element> for XmlElementIndex { + fn from(value: &xmltree::Element) -> Self { + let mut children = HashMap::new(); + for (i, e) in value + .children + .iter() + .enumerate() + .filter_map(|(i, v)| v.as_element().map(|e| (i, e))) + { + children + .entry(e.name.clone()) + .or_insert_with(Vec::new) + .push(i); + } + Self { children } + } + } + + pub struct Element<'a> { + inner: &'a xmltree::Element, + child_element_index: XmlElementIndex, + } + + impl<'a> From<&'a xmltree::Element> for Element<'a> { + fn from(value: &'a xmltree::Element) -> Self { + let element_index = XmlElementIndex::from(value); + Self { + inner: value, + child_element_index: element_index, + } + } + } + + impl Element<'_> { + pub fn name(&self) -> &str { + self.inner.name.as_str() + } + + pub fn get_child_text(&self, tag: &str) -> Option { + let index = self.child_element_index.get_first(tag)?; + self.inner.children[index] + .as_element()? + .get_text() + .map(|v| v.to_string()) + } + + pub fn get_child_text_or_error(&self, tag: &str) -> Result { + let i = self + .child_element_index + .get_first(tag) + .ok_or(Error::XmlError(format!("<{}> tag not found", tag)))?; + self.inner.children[i] + .as_element() + .unwrap() + .get_text() + .map(|x| x.to_string()) + .ok_or(Error::XmlError(format!("text of <{}> tag not found", tag))) + } + + // Returns all children with given tag along with their index. + pub fn get_matching_children(&self, tag: &str) -> Vec<(usize, Element)> { + self.child_element_index + .get(tag) + .unwrap_or(&vec![]) + .into_iter() + .map(|i| (*i, self.inner.children[*i].as_element().unwrap().into())) + .collect() + } + + pub fn get_child(&self, tag: &str) -> Option { + let index = self.child_element_index.get_first(tag)?; + Some(self.inner.children[index].as_element()?.into()) + } + + pub fn get_xmltree_children(&self) -> Vec<&xmltree::Element> { + self.inner + .children + .iter() + .filter_map(|v| v.as_element()) + .collect() + } + } + + // Helper type that implements merge sort in the iterator. + pub struct MergeXmlElements<'a> { + v1: &'a Vec<(usize, Element<'a>)>, + v2: &'a Vec<(usize, Element<'a>)>, + i1: usize, + i2: usize, + } + + impl<'a> MergeXmlElements<'a> { + pub fn new(v1: &'a Vec<(usize, Element<'a>)>, v2: &'a Vec<(usize, Element<'a>)>) -> Self { + Self { + v1, + v2, + i1: 0, + i2: 0, + } + } + } + + impl<'a> Iterator for MergeXmlElements<'a> { + type Item = &'a Element<'a>; + + fn next(&mut self) -> Option { + let c1 = self.v1.get(self.i1); + let c2 = self.v2.get(self.i2); + match (c1, c2) { + (Some(val1), Some(val2)) => { + if val1.0 < val2.0 { + self.i1 += 1; + Some(&val1.1) + } else { + self.i2 += 1; + Some(&val2.1) + } + } + (Some(val1), None) => { + self.i1 += 1; + Some(&val1.1) + } + (None, Some(val2)) => { + self.i2 += 1; + Some(&val2.1) + } + (None, None) => None, + } + } + } +}