Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_column_usage utility function #516

Merged
merged 2 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 16 additions & 20 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ platforms = ["osx-arm64", "osx-64", "linux-64", "win-64"]
macos = "12.0"

[tasks]
fmt-rs = "cargo fmt --all"
check-rs-fmt = "cargo fmt --all -- --check"
check-rs-warnings = "export RUSTFLAGS=\"-D warnings\" && cargo check --tests"
check-rs-clippy = "cargo clippy -- -A clippy::borrow_deref_ref"
Expand Down
85 changes: 84 additions & 1 deletion vegafusion-core/src/planning/projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ use crate::spec::transform::project::ProjectTransformSpec;
use crate::spec::transform::{TransformColumns, TransformSpec};
use crate::task_graph::graph::ScopedVariable;
use crate::task_graph::scope::TaskScope;
use itertools::sorted;
use itertools::{sorted, Itertools};
use petgraph::algo::toposort;
use std::collections::HashMap;
use vegafusion_common::arrow::array::StringArray;
use vegafusion_common::data::table::VegaFusionTable;
use vegafusion_common::error::Result;
Expand Down Expand Up @@ -51,6 +52,88 @@ pub fn projection_pushdown(chart_spec: &mut ChartSpec) -> Result<()> {
Ok(())
}

/// Compute which columns of each top-level root dataset are referenced by a Vega spec.
///
/// A root dataset is one without a `source`. The spec is cloned before any rewriting, so
/// the caller's `chart_spec` is left untouched.
///
/// # Returns
/// Map from root dataset name to either:
///  - `None` if column usage could not be determined
///  - `Some(columns)` holding the sorted names of the referenced columns if usage could
///    be determined precisely
///
/// # Errors
/// Propagates any error from building the task scope or from walking the spec.
pub fn get_column_usage(chart_spec: &ChartSpec) -> Result<HashMap<String, Option<Vec<String>>>> {
    let suffix = "__column_usage_root";
    let mut chart_spec = chart_spec.clone();

    // A root dataset that carries transforms is split in two: the original entry is renamed
    // with `suffix` and keeps the raw data, while a new dataset under the original name
    // sources from it and takes over the transforms. The usage computed for the renamed
    // entry then refers to the source data rather than the transformed result.
    let mut split_specs: Vec<DataSpec> = Vec::new();
    for root in chart_spec
        .data
        .iter_mut()
        .filter(|spec| spec.source.is_none() && !spec.transform.is_empty())
    {
        let name = &root.name;
        let renamed = format!("{name}{suffix}");
        let original_name = std::mem::replace(&mut root.name, renamed.clone());

        split_specs.push(DataSpec {
            name: original_name,
            source: Some(renamed),
            transform: std::mem::take(&mut root.transform),
            ..Default::default()
        });
    }
    chart_spec.data.extend(split_specs);

    let task_scope = chart_spec.to_task_scope()?;

    // Collect field usage for vlSelectionTest datasets
    let mut selection_visitor = CollectVlSelectionTestFieldsVisitor::new(task_scope.clone());
    chart_spec.walk(&mut selection_visitor)?;
    let vl_selection_fields = selection_visitor.vl_selection_fields;

    let no_datum_var = None;
    let top_level_scope = Vec::new();
    let all_usages = chart_spec.datasets_column_usage(
        &no_datum_var,
        top_level_scope.as_slice(),
        &task_scope,
        &vl_selection_fields,
    );

    let root_dataset_columns = chart_spec
        .data
        .iter()
        .filter(|spec| spec.source.is_none())
        .map(|spec| {
            let key = (
                Variable::new(VariableNamespace::Data, &spec.name),
                Vec::new(),
            );

            // A dataset with no recorded usage is treated the same as unknown usage
            let columns: Option<Vec<String>> = match all_usages.usages.get(&key) {
                Some(ColumnUsage::Known(used)) => Some(used.iter().cloned().sorted().collect()),
                Some(ColumnUsage::Unknown) | None => None,
            };

            // Report under the original name, without the suffix added by the split above
            let reported_name = match spec.name.strip_suffix(suffix) {
                Some(stripped) => stripped.to_string(),
                None => spec.name.clone(),
            };

            (reported_name, columns)
        })
        .collect();

    Ok(root_dataset_columns)
}

impl GetDatasetsColumnUsage for MarkEncodingField {
fn datasets_column_usage(
&self,
Expand Down
2 changes: 1 addition & 1 deletion vegafusion-core/src/spec/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::collections::{HashMap, HashSet};
use vegafusion_common::data::table::VegaFusionTable;
use vegafusion_common::error::VegaFusionError;

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
pub struct DataSpec {
pub name: String,

Expand Down
11 changes: 11 additions & 0 deletions vegafusion-python/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod connection;

use pyo3;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyList, PyTuple};
Expand All @@ -19,6 +20,7 @@ use serde_json::json;
use vegafusion_common::data::table::VegaFusionTable;
use vegafusion_core::patch::patch_pre_transformed_spec;
use vegafusion_core::planning::plan::{PlannerConfig, PreTransformSpecWarningSpec, SpecPlan};
use vegafusion_core::planning::projection_pushdown::get_column_usage as rs_get_column_usage;
use vegafusion_core::planning::watch::{ExportUpdateJSON, WatchPlan};
use vegafusion_core::proto::gen::tasks::{TzConfig, Variable};
use vegafusion_core::spec::chart::ChartSpec;
Expand Down Expand Up @@ -614,6 +616,14 @@ impl PyVegaFusionRuntime {
}
}

/// Python entry point for the column-usage computation.
///
/// Converts `spec` with `parse_json_spec`, runs the Rust `get_column_usage`
/// implementation on the result, and converts the returned map into a Python
/// object with `pythonize`. A failure at any of the three steps is returned
/// to the caller as the function's error.
#[pyfunction]
#[pyo3(signature = (spec))]
pub fn get_column_usage(py: Python, spec: PyObject) -> PyResult<PyObject> {
    let chart_spec = parse_json_spec(spec)?;
    let column_usage = rs_get_column_usage(&chart_spec)?;
    let py_usage = pythonize::pythonize(py, &column_usage)?;
    Ok(py_usage.into())
}

/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
Expand All @@ -622,6 +632,7 @@ fn _vegafusion(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_class::<PyVegaFusionRuntime>()?;
m.add_class::<PySqlConnection>()?;
m.add_class::<PyChartState>()?;
m.add_function(wrap_pyfunction!(get_column_usage, m)?)?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Ok(())
}
Expand Down
17 changes: 17 additions & 0 deletions vegafusion-python/tests/test_get_column_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import json
from pathlib import Path


import vegafusion as vf

here = Path(__file__).parent

spec_dir = here / ".." / ".." / "vegafusion-runtime" / "tests" / "specs"


def test_get_column_usage():
    """get_column_usage reports the referenced columns of the root dataset."""
    spec_path = spec_dir / "vegalite" / "concat_marginal_histograms.vg.json"
    spec = json.loads(spec_path.read_text(encoding="utf8"))

    expected = {"source_0": ["IMDB Rating", "Rotten Tomatoes Rating"]}
    actual = vf.get_column_usage(spec)

    assert actual == expected
3 changes: 3 additions & 0 deletions vegafusion-python/vegafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ._vegafusion import __version__
from .local_tz import get_local_tz, set_local_tz
from .runtime import runtime
from .utils import get_column_usage


def patched_version(distribution_name: str) -> str:
Expand All @@ -20,8 +21,10 @@ def patched_version(distribution_name: str) -> str:
# Patch importlib.metadata.version to handle our dummy package
importlib.metadata.version = patched_version


# Names exported by `from vegafusion import *`; each one is imported at the
# top of this module.
__all__ = [
"runtime",
"set_local_tz",
"get_local_tz",
"get_column_usage",
]
24 changes: 24 additions & 0 deletions vegafusion-python/vegafusion/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import Any, cast

from ._vegafusion import get_column_usage as _get_column_usage


def get_column_usage(spec: dict[str, Any]) -> dict[str, list[str] | None]:
    """
    Determine which columns of every root dataset a Vega spec references.

    Args:
        spec: Vega spec

    Returns:
        dict[str, list[str] | None]: Mapping keyed by root-level dataset
            name. Each value is one of:
            - The list of columns referenced from that dataset, when this
              can be determined precisely
            - None, when the full set of columns referenced from that
              dataset could not be determined
    """
    # The native function returns an untyped object; the cast only narrows
    # the static type and does no runtime conversion.
    usage = _get_column_usage(spec)
    return cast("dict[str, list[str] | None]", usage)
34 changes: 32 additions & 2 deletions vegafusion-runtime/tests/test_projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ fn crate_dir() -> String {
#[cfg(test)]
mod test_custom_specs {
use crate::crate_dir;
use itertools::Itertools;
use rstest::rstest;
use std::fs;
use vegafusion_core::planning::plan::{PlannerConfig, SpecPlan};
use vegafusion_core::planning::projection_pushdown::get_column_usage;
use vegafusion_core::spec::chart::ChartSpec;
use vegafusion_core::spec::transform::TransformSpec;

Expand All @@ -32,7 +34,7 @@ mod test_custom_specs {
"bin_maxbins_60_IMDB Rating_end"
]),
)]
fn test(spec_name: &str, data_index: usize, projection_fields: Vec<&str>) {
fn test_proj_pushdown(spec_name: &str, data_index: usize, projection_fields: Vec<&str>) {
// Load spec
let spec_path = format!("{}/tests/specs/{}.vg.json", crate_dir(), spec_name);
let spec_str = fs::read_to_string(spec_path).unwrap();
Expand All @@ -49,12 +51,40 @@ mod test_custom_specs {

// Print data
// println!("{}", serde_json::to_string_pretty(&spec_plan.server_spec.data).unwrap());
let expected_fields: Vec<_> = projection_fields
.into_iter()
.map(String::from)
.sorted()
.collect();

if let TransformSpec::Project(project) = tx {
let expected_fields: Vec<_> = projection_fields.iter().map(|f| f.to_string()).collect();
assert_eq!(project.fields, expected_fields);
} else {
panic!("Expected project transform")
}
}

/// Checks that `get_column_usage` reports the expected columns for the `source_0`
/// root dataset of each listed spec.
#[rstest(
    spec_name,
    projection_fields,
    case("vegalite/point_2d", vec!["Horsepower", "Miles_per_Gallon"]),
    case("vegalite/point_bubble", vec!["Acceleration", "Horsepower", "Miles_per_Gallon"]),
    case("vegalite/concat_marginal_histograms", vec![
        "IMDB Rating",
        "Rotten Tomatoes Rating",
    ]),
    case("vegalite/rect_binned_heatmap", vec![
        "IMDB Rating",
        "Rotten Tomatoes Rating",
    ]),
)]
fn test_get_column_usage(spec_name: &str, projection_fields: Vec<&str>) {
    // Load spec
    let spec_path = format!("{}/tests/specs/{}.vg.json", crate_dir(), spec_name);
    let spec_str = fs::read_to_string(spec_path).unwrap();
    let spec: ChartSpec = serde_json::from_str(&spec_str).unwrap();
    let root_usages = get_column_usage(&spec).unwrap();

    // `get_column_usage` returns column names sorted, so sort the expectation too;
    // the order in which a case lists its fields then cannot cause a spurious failure.
    let mut expected_fields: Vec<String> =
        projection_fields.into_iter().map(String::from).collect();
    expected_fields.sort();

    // Look the dataset up explicitly so a missing entry fails with a readable message
    // instead of an index panic.
    let source_usage = root_usages
        .get("source_0")
        .unwrap_or_else(|| panic!("no column usage reported for `source_0` in {spec_name}"));
    assert_eq!(source_usage, &Some(expected_fields));
}
}
Loading