Skip to content

Commit

Permalink
Add get_column_usage utility function
Browse files Browse the repository at this point in the history
  • Loading branch information
jonmmease committed Oct 9, 2024
1 parent 151163c commit 7956756
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 4 deletions.
85 changes: 84 additions & 1 deletion vegafusion-core/src/planning/projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ use crate::spec::transform::project::ProjectTransformSpec;
use crate::spec::transform::{TransformColumns, TransformSpec};
use crate::task_graph::graph::ScopedVariable;
use crate::task_graph::scope::TaskScope;
use itertools::sorted;
use itertools::{sorted, Itertools};
use petgraph::algo::toposort;
use std::collections::HashMap;
use vegafusion_common::arrow::array::StringArray;
use vegafusion_common::data::table::VegaFusionTable;
use vegafusion_common::error::Result;
Expand Down Expand Up @@ -51,6 +52,88 @@ pub fn projection_pushdown(chart_spec: &mut ChartSpec) -> Result<()> {
Ok(())
}

/// Get column usage info for the top-level root datasets of a Vega spec
/// Returns map from dataset name to either:
/// - None if column usage could not be determined
/// - Vec<String> of the referenced columns if column usage could be determined precisely
pub fn get_column_usage(chart_spec: &ChartSpec) -> Result<HashMap<String, Option<Vec<String>>>> {
let mut chart_spec = chart_spec.clone();

// split root nodes that have transforms so that the usage we compute refers to the source
// data, not the result after transforms
let mut new_data_specs: Vec<DataSpec> = Vec::new();
let suffix = "__column_usage_root";
for data_spec in &mut chart_spec.data {
if data_spec.source.is_none() && !data_spec.transform.is_empty() {
// This is a root dataset that has transforms, so we split it
let name = data_spec.name.clone();
let mut transforms = Vec::new();
transforms.append(&mut data_spec.transform);
let root_name = format!("{name}{suffix}");
data_spec.name = root_name.clone();

let new_spec = DataSpec {
name: name.clone(),
source: Some(root_name),
transform: transforms,
..Default::default()
};

new_data_specs.push(new_spec);
}
}

chart_spec.data.append(&mut new_data_specs);

let datum_var = None;
let usage_scope = Vec::new();
let task_scope = chart_spec.to_task_scope()?;

// Collect field usage for vlSelectionTest datasets
let mut vl_selection_visitor = CollectVlSelectionTestFieldsVisitor::new(task_scope.clone());
chart_spec.walk(&mut vl_selection_visitor)?;
let vl_selection_fields = vl_selection_visitor.vl_selection_fields;

let datasets_column_usage = chart_spec.datasets_column_usage(
&datum_var,
usage_scope.as_slice(),
&task_scope,
&vl_selection_fields,
);

let mut root_dataset_columns: HashMap<String, Option<Vec<String>>> = HashMap::new();
for data_spec in &chart_spec.data {
if data_spec.source.is_none() {
let var = Variable::new(VariableNamespace::Data, &data_spec.name);
let scoped_var = (var, Vec::new());
let column_usage = datasets_column_usage
.usages
.get(&scoped_var)
.unwrap_or(&ColumnUsage::Unknown);

// Remove root dataset suffix that was added above
let original_name = data_spec
.name
.strip_suffix(suffix)
.unwrap_or(&data_spec.name)
.to_string();

match column_usage {
ColumnUsage::Unknown => {
root_dataset_columns.insert(original_name.clone(), None);
}
ColumnUsage::Known(used) => {
root_dataset_columns.insert(
original_name.clone(),
Some(used.iter().cloned().sorted().collect()),
);
}
}
}
}
Ok(root_dataset_columns)
}

impl GetDatasetsColumnUsage for MarkEncodingField {
fn datasets_column_usage(
&self,
Expand Down
2 changes: 1 addition & 1 deletion vegafusion-core/src/spec/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::collections::{HashMap, HashSet};
use vegafusion_common::data::table::VegaFusionTable;
use vegafusion_common::error::VegaFusionError;

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
pub struct DataSpec {
pub name: String,

Expand Down
34 changes: 32 additions & 2 deletions vegafusion-runtime/tests/test_projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ fn crate_dir() -> String {
#[cfg(test)]
mod test_custom_specs {
use crate::crate_dir;
use itertools::Itertools;
use rstest::rstest;
use std::fs;
use vegafusion_core::planning::plan::{PlannerConfig, SpecPlan};
use vegafusion_core::planning::projection_pushdown::get_column_usage;
use vegafusion_core::spec::chart::ChartSpec;
use vegafusion_core::spec::transform::TransformSpec;

Expand All @@ -32,7 +34,7 @@ mod test_custom_specs {
"bin_maxbins_60_IMDB Rating_end"
]),
)]
fn test(spec_name: &str, data_index: usize, projection_fields: Vec<&str>) {
fn test_proj_pushdown(spec_name: &str, data_index: usize, projection_fields: Vec<&str>) {
// Load spec
let spec_path = format!("{}/tests/specs/{}.vg.json", crate_dir(), spec_name);
let spec_str = fs::read_to_string(spec_path).unwrap();
Expand All @@ -49,12 +51,40 @@ mod test_custom_specs {

// Print data
// println!("{}", serde_json::to_string_pretty(&spec_plan.server_spec.data).unwrap());
let expected_fields: Vec<_> = projection_fields
.into_iter()
.map(String::from)
.sorted()
.collect();

if let TransformSpec::Project(project) = tx {
let expected_fields: Vec<_> = projection_fields.iter().map(|f| f.to_string()).collect();
assert_eq!(project.fields, expected_fields);
} else {
panic!("Expected project transform")
}
}

# [rstest(
spec_name,
projection_fields,
case("vegalite/point_2d", vec!["Horsepower", "Miles_per_Gallon"]),
case("vegalite/point_bubble", vec!["Acceleration", "Horsepower", "Miles_per_Gallon"]),
case("vegalite/concat_marginal_histograms", vec![
"IMDB Rating",
"Rotten Tomatoes Rating",
]),
case("vegalite/rect_binned_heatmap", vec![
"IMDB Rating",
"Rotten Tomatoes Rating",
]),
)]
fn test_get_column_usage(spec_name: &str, projection_fields: Vec<&str>) {
// Load spec
let spec_path = format!("{}/tests/specs/{}.vg.json", crate_dir(), spec_name);
let spec_str = fs::read_to_string(spec_path).unwrap();
let spec: ChartSpec = serde_json::from_str(&spec_str).unwrap();
let root_usages = get_column_usage(&spec).unwrap();
let expected_fields: Vec<_> = projection_fields.into_iter().map(String::from).collect();
assert_eq!(root_usages["source_0"], Some(expected_fields));
}
}

0 comments on commit 7956756

Please sign in to comment.