diff --git a/vegafusion-core/src/planning/projection_pushdown.rs b/vegafusion-core/src/planning/projection_pushdown.rs index 1eaf222c..10ea8ef5 100644 --- a/vegafusion-core/src/planning/projection_pushdown.rs +++ b/vegafusion-core/src/planning/projection_pushdown.rs @@ -18,8 +18,9 @@ use crate::spec::transform::project::ProjectTransformSpec; use crate::spec::transform::{TransformColumns, TransformSpec}; use crate::task_graph::graph::ScopedVariable; use crate::task_graph::scope::TaskScope; -use itertools::sorted; +use itertools::{sorted, Itertools}; use petgraph::algo::toposort; +use std::collections::HashMap; use vegafusion_common::arrow::array::StringArray; use vegafusion_common::data::table::VegaFusionTable; use vegafusion_common::error::Result; @@ -51,6 +52,88 @@ pub fn projection_pushdown(chart_spec: &mut ChartSpec) -> Result<()> { Ok(()) } +/// Get column usage info for the top-level root datasets of a Vega spec +/// Returns map from dataset name to either: +/// - None if column usage could not be determined +/// - Vec of the referenced columns if column usage could be determined precisely +pub fn get_column_usage(chart_spec: &ChartSpec) -> Result>>> { + let mut chart_spec = chart_spec.clone(); + + // split root nodes that have transforms so that the usage we compute refers to the source + // data, not the result after transforms + let mut new_data_specs: Vec = Vec::new(); + let suffix = "__column_usage_root"; + for data_spec in &mut chart_spec.data { + if data_spec.source.is_none() && !data_spec.transform.is_empty() { + // This is a root dataset that has transforms, so we split it + let name = data_spec.name.clone(); + let mut transforms = Vec::new(); + transforms.append(&mut data_spec.transform); + let root_name = format!("{name}{suffix}"); + data_spec.name = root_name.clone(); + + let new_spec = DataSpec { + name: name.clone(), + source: Some(root_name), + transform: transforms, + ..Default::default() + }; + + new_data_specs.push(new_spec); + } + } + + chart_spec.data.append(&mut new_data_specs); + + let datum_var = None; + let usage_scope = Vec::new(); + let task_scope = chart_spec.to_task_scope()?; + + // Collect field usage for vlSelectionTest datasets + let mut vl_selection_visitor = CollectVlSelectionTestFieldsVisitor::new(task_scope.clone()); + chart_spec.walk(&mut vl_selection_visitor)?; + let vl_selection_fields = vl_selection_visitor.vl_selection_fields; + + let datasets_column_usage = chart_spec.datasets_column_usage( + &datum_var, + usage_scope.as_slice(), + &task_scope, + &vl_selection_fields, + ); + + let mut root_dataset_columns: HashMap>> = HashMap::new(); + for data_spec in &chart_spec.data { + if data_spec.source.is_none() { + let var = Variable::new(VariableNamespace::Data, &data_spec.name); + let scoped_var = (var, Vec::new()); + let column_usage = datasets_column_usage + .usages + .get(&scoped_var) + .unwrap_or(&ColumnUsage::Unknown); + + // Remove root dataset suffix that was added above + let original_name = data_spec + .name + .strip_suffix(suffix) + .unwrap_or(&data_spec.name) + .to_string(); + + match column_usage { + ColumnUsage::Unknown => { + root_dataset_columns.insert(original_name.clone(), None); + } + ColumnUsage::Known(used) => { + root_dataset_columns.insert( + original_name.clone(), + Some(used.iter().cloned().sorted().collect()), + ); + } + } + } + } + Ok(root_dataset_columns) +} + impl GetDatasetsColumnUsage for MarkEncodingField { fn datasets_column_usage( &self, diff --git a/vegafusion-core/src/spec/data.rs b/vegafusion-core/src/spec/data.rs index c3fa0f81..56a1c92d 100644 --- a/vegafusion-core/src/spec/data.rs +++ b/vegafusion-core/src/spec/data.rs @@ -12,7 +12,7 @@ use std::collections::{HashMap, HashSet}; use vegafusion_common::data::table::VegaFusionTable; use vegafusion_common::error::VegaFusionError; -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)] pub struct DataSpec { pub name: String, diff --git a/vegafusion-runtime/tests/test_projection_pushdown.rs b/vegafusion-runtime/tests/test_projection_pushdown.rs index 886cfa46..c5eb6b08 100644 --- a/vegafusion-runtime/tests/test_projection_pushdown.rs +++ b/vegafusion-runtime/tests/test_projection_pushdown.rs @@ -6,9 +6,11 @@ fn crate_dir() -> String { #[cfg(test)] mod test_custom_specs { use crate::crate_dir; + use itertools::Itertools; use rstest::rstest; use std::fs; use vegafusion_core::planning::plan::{PlannerConfig, SpecPlan}; + use vegafusion_core::planning::projection_pushdown::get_column_usage; use vegafusion_core::spec::chart::ChartSpec; use vegafusion_core::spec::transform::TransformSpec; @@ -32,7 +34,7 @@ mod test_custom_specs { "bin_maxbins_60_IMDB Rating_end" ]), )] - fn test(spec_name: &str, data_index: usize, projection_fields: Vec<&str>) { + fn test_proj_pushdown(spec_name: &str, data_index: usize, projection_fields: Vec<&str>) { // Load spec let spec_path = format!("{}/tests/specs/{}.vg.json", crate_dir(), spec_name); let spec_str = fs::read_to_string(spec_path).unwrap(); @@ -49,12 +51,40 @@ mod test_custom_specs { // Print data // println!("{}", serde_json::to_string_pretty(&spec_plan.server_spec.data).unwrap()); + let expected_fields: Vec<_> = projection_fields + .into_iter() + .map(String::from) + .sorted() + .collect(); if let TransformSpec::Project(project) = tx { - let expected_fields: Vec<_> = projection_fields.iter().map(|f| f.to_string()).collect(); assert_eq!(project.fields, expected_fields); } else { panic!("Expected project transform") } } + + # [rstest( + spec_name, + projection_fields, + case("vegalite/point_2d", vec!["Horsepower", "Miles_per_Gallon"]), + case("vegalite/point_bubble", vec!["Acceleration", "Horsepower", "Miles_per_Gallon"]), + case("vegalite/concat_marginal_histograms", vec![ + "IMDB Rating", + "Rotten Tomatoes Rating", + ]), + case("vegalite/rect_binned_heatmap", vec![ + "IMDB Rating", + "Rotten Tomatoes Rating", + ]), + )] + fn test_get_column_usage(spec_name: &str, projection_fields: Vec<&str>) { + // Load spec + let spec_path = format!("{}/tests/specs/{}.vg.json", crate_dir(), spec_name); + let spec_str = fs::read_to_string(spec_path).unwrap(); + let spec: ChartSpec = serde_json::from_str(&spec_str).unwrap(); + let root_usages = get_column_usage(&spec).unwrap(); + let expected_fields: Vec<_> = projection_fields.into_iter().map(String::from).collect(); + assert_eq!(root_usages["source_0"], Some(expected_fields)); + } }