diff --git a/src/stripe.rs b/src/stripe.rs index d43e9cee..3f043610 100644 --- a/src/stripe.rs +++ b/src/stripe.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashSet; use std::{collections::HashMap, io::Read, sync::Arc}; use bytes::Bytes; @@ -139,23 +140,23 @@ impl Stripe { .context(IoSnafu)?; let footer = Arc::new(deserialize_stripe_footer(footer, compression)?); - let columns = projected_data_type + let columns: Vec<Column> = projected_data_type .children() .iter() .map(|col| Column::new(col.name(), col.data_type(), &footer)) .collect(); + let column_ids = collect_required_column_ids(&columns); let mut stream_map = HashMap::new(); let mut stream_offset = info.offset(); for stream in &footer.streams { let length = stream.length(); let column_id = stream.column(); - let kind = stream.kind(); - let data = Column::read_stream(reader, stream_offset, length)?; - - // TODO(weny): filter out unused streams. - stream_map.insert((column_id, kind), data); - + if column_ids.contains(&column_id) { + let kind = stream.kind(); + let data = Column::read_stream(reader, stream_offset, length)?; + stream_map.insert((column_id, kind), data); + } stream_offset += length; } @@ -192,22 +193,23 @@ impl Stripe { .context(IoSnafu)?; let footer = Arc::new(deserialize_stripe_footer(footer, compression)?); - let columns = projected_data_type + let columns: Vec<Column> = projected_data_type .children() .iter() .map(|col| Column::new(col.name(), col.data_type(), &footer)) .collect(); + let column_ids = collect_required_column_ids(&columns); let mut stream_map = HashMap::new(); let mut stream_offset = info.offset(); for stream in &footer.streams { let length = stream.length(); let column_id = stream.column(); - let kind = stream.kind(); - let data = Column::read_stream_async(reader, stream_offset, length).await?; - - // TODO(weny): filter out unused streams. 
- stream_map.insert((column_id, kind), data); + if column_ids.contains(&column_id) { + let kind = stream.kind(); + let data = Column::read_stream_async(reader, stream_offset, length).await?; + stream_map.insert((column_id, kind), data); + } stream_offset += length; } @@ -282,3 +284,12 @@ fn deserialize_stripe_footer( .context(error::IoSnafu)?; StripeFooter::decode(buffer.as_slice()).context(error::DecodeProtoSnafu) } + +fn collect_required_column_ids(columns: &[Column]) -> HashSet<u32> { + let mut set = HashSet::new(); + for column in columns { + set.insert(column.column_id()); + set.extend(collect_required_column_ids(&column.children())); + } + set +}