From f2d4f7381d2921349f321ac4959a57a99622908a Mon Sep 17 00:00:00 2001 From: Desmond Cheong Date: Wed, 27 Nov 2024 17:35:09 -0800 Subject: [PATCH] [PERF] Improve hash table probe side decisions for Swordfish (#3327) This PR lifts statistics into optimized logical plans so that they're available for local execution plans. It then uses these newly available statistics to make better decisions on whether to build the probe table of a hash join on the left or right side. ## Benchmark results For TPC-H, this gives notable speedups on Q5, Q8, and Q19. - Q5: ~2.2x (in memory), ~1.5x (parquet) - Q8: ~5.9x (in memory), ~2.2x (parquet) - Q19: ~7x (in memory), ~4.5x (parquet) Crucially, with this change, our native runner is now faster than (or within a small margin of) our previous Python runner for all 22 TPC-H queries. Detailed results: ### Q5 #### Before ``` --------------------------------------------------------------------------------- benchmark 'q5-parts-1': 4 tests --------------------------------------------------------------------------------- Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- test_tpch[1-in-memory-python-5] 28.7213 (1.0) 30.3664 (1.0) 29.4645 (1.0) 0.3849 (1.0) 29.4894 (1.0) 0.5529 (1.0) 10;0 33.9391 (1.0) 33 1 test_tpch[1-in-memory-native-5] 30.9980 (1.08) 34.1489 (1.12) 32.2151 (1.09) 0.7150 (1.86) 32.2903 (1.09) 1.1586 (2.10) 10;0 31.0413 (0.91) 31 1 test_tpch[1-parquet-python-5] 48.8010 (1.70) 51.8535 (1.71) 50.0985 (1.70) 0.8342 (2.17) 50.0400 (1.70) 1.4193 (2.57) 9;0 19.9607 (0.59) 20 1 test_tpch[1-parquet-native-5] 51.1122 (1.78) 54.0755 (1.78) 52.3799 (1.78) 0.8317 (2.16) 52.4268 (1.78) 1.2526 (2.27) 8;0 19.0913 (0.56) 20 1 --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- ``` #### After ``` --------------------------------------------------------------------------------- benchmark 'q5-parts-1': 4 tests --------------------------------------------------------------------------------- Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- test_tpch[1-in-memory-native-5] 13.8393 (1.0) 16.2712 (1.0) 14.4352 (1.0) 0.5322 (2.13) 14.3212 (1.0) 0.5084 (1.71) 9;5 69.2750 (1.0) 62 1 test_tpch[1-in-memory-python-5] 28.3915 (2.05) 29.5304 (1.81) 28.8477 (2.00) 0.2501 (1.0) 28.8299 (2.01) 0.2971 (1.0) 10;1 34.6649 (0.50) 35 1 test_tpch[1-parquet-native-5] 34.3952 (2.49) 36.4230 (2.24) 35.4314 (2.45) 0.4543 (1.82) 35.3678 (2.47) 0.5042 (1.70) 8;1 28.2235 (0.41) 28 1 test_tpch[1-parquet-python-5] 55.7339 (4.03) 57.3051 (3.52) 56.4564 (3.91) 0.5101 (2.04) 56.2507 (3.93) 0.7570 (2.55) 4;0 17.7128 (0.26) 18 1 --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- ``` ### Q8 #### Before ``` --------------------------------------------------------------------------------- benchmark 'q8-parts-1': 4 tests
--------------------------------------------------------------------------------- Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- test_tpch[1-in-memory-python-8] 14.9532 (1.0) 17.2992 (1.0) 16.1906 (1.0) 0.5962 (1.22) 16.1370 (1.0) 0.8323 (1.13) 23;0 61.7642 (1.0) 60 1 test_tpch[1-parquet-python-8] 34.7310 (2.32) 52.5183 (3.04) 37.1660 (2.30) 3.2354 (6.61) 36.5454 (2.26) 2.2469 (3.04) 1;1 26.9063 (0.44) 28 1 test_tpch[1-in-memory-native-8] 44.0259 (2.94) 46.0576 (2.66) 45.0905 (2.78) 0.4898 (1.0) 45.0528 (2.79) 0.7380 (1.0) 5;0 22.1776 (0.36) 22 1 test_tpch[1-parquet-native-8] 69.8245 (4.67) 73.1332 (4.23) 71.0333 (4.39) 0.8421 (1.72) 70.8515 (4.39) 0.9827 (1.33) 3;1 14.0779 (0.23) 14 1 --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- ``` #### After ``` --------------------------------------------------------------------------------- benchmark 'q8-parts-1': 4 tests ---------------------------------------------------------------------------------- Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- test_tpch[1-in-memory-native-8] 7.2145 (1.0) 8.5953 (1.0) 7.7980 (1.0) 0.3170 (1.0) 7.8074 (1.0) 0.4959 (1.13) 44;0 128.2373 (1.0) 118 1 test_tpch[1-in-memory-python-8] 15.3458 (2.13) 17.6777 (2.06) 16.3485 (2.10) 0.5217 (1.65) 16.5234 (2.12) 0.7077 (1.62) 17;0 61.1678 (0.48) 57 1 test_tpch[1-parquet-native-8] 31.8017 (4.41) 33.5568 (3.90) 32.4955 (4.17) 0.4402 (1.39) 32.4145 (4.15) 0.4376 (1.0) 9;2 30.7735 (0.24) 30 1 test_tpch[1-parquet-python-8] 46.2703 (6.41) 48.3345 (5.62) 47.6753 (6.11) 0.5142 (1.62) 47.7092 (6.11) 0.4509 (1.03) 6;2 20.9752 (0.16) 21 1 ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- ``` ### Q19 #### Before ``` ------------------------------------------------------------------------------------ benchmark 'q19-parts-1': 4 tests ----------------------------------------------------------------------------------- Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- test_tpch[1-in-memory-python-19] 277.9533 (1.0) 281.5647 (1.0) 279.8556 (1.0) 1.4607 (1.0) 279.4186 (1.0) 2.2466 (1.0) 2;0 3.5733 (1.0) 5 1 test_tpch[1-parquet-python-19] 311.1196 (1.12) 317.8069 (1.13) 315.6200 (1.13) 2.6611 (1.82) 316.7545 (1.13) 2.8849 (1.28) 1;0 3.1684 (0.89) 5 1 test_tpch[1-in-memory-native-19] 431.2738 (1.55) 464.2194 (1.65) 442.1488 (1.58) 13.3136 (9.11) 436.8320 (1.56) 16.3197 (7.26) 1;0 2.2617 (0.63) 5 1 test_tpch[1-parquet-native-19] 455.3492 (1.64) 460.8460 (1.64) 458.0333 (1.64) 2.1169 (1.45) 457.4410 (1.64) 3.0005 (1.34) 2;0 2.1832 (0.61) 5 1 
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- ``` #### After ``` ------------------------------------------------------------------------------------ benchmark 'q19-parts-1': 4 tests ----------------------------------------------------------------------------------- Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- test_tpch[1-in-memory-native-19] 62.6100 (1.0) 71.2192 (1.0) 66.0757 (1.0) 2.5634 (1.0) 65.6719 (1.0) 1.5398 (1.0) 6;4 15.1342 (1.0) 15 1 test_tpch[1-parquet-native-19] 94.8984 (1.52) 134.7584 (1.89) 103.1099 (1.56) 12.5712 (4.90) 97.7583 (1.49) 7.3370 (4.76) 1;1 9.6984 (0.64) 9 1 test_tpch[1-in-memory-python-19] 284.6653 (4.55) 295.5558 (4.15) 289.7268 (4.38) 3.9986 (1.56) 288.6982 (4.40) 4.6399 (3.01) 2;0 3.4515 (0.23) 5 1 test_tpch[1-parquet-python-19] 308.9599 (4.93) 319.0440 (4.48) 314.7801 (4.76) 4.2088 (1.64) 315.3274 (4.80) 7.0198 (4.56) 2;0 3.1768 (0.21) 5 1 --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- ``` --------- Co-authored-by: Sammy Sidhu --- Cargo.lock | 11 ++ daft/dataframe/dataframe.py | 5 +- src/common/daft-config/src/lib.rs | 2 +- src/common/scan-info/src/lib.rs | 34 +++- src/common/scan-info/src/scan_operator.rs | 7 +- src/common/scan-info/src/scan_task.rs | 27 ++- src/common/scan-info/src/test/mod.rs | 18 +- src/daft-catalog/src/lib.rs | 12 +- src/daft-connect/src/op/execute/root.rs | 6 +- src/daft-local-execution/src/pipeline.rs | 51 +++++- .../src/sinks/outer_hash_join_probe.rs | 26 +++ src/daft-local-plan/src/plan.rs | 153 +++++++++++------ src/daft-local-plan/src/translate.rs | 49 +++++- src/daft-logical-plan/src/builder.rs | 4 +- src/daft-logical-plan/src/display.rs | 28 +-- src/daft-logical-plan/src/lib.rs | 1 + src/daft-logical-plan/src/logical_plan.rs | 85 ++++++++-- .../src/ops/actor_pool_project.rs | 13 ++ src/daft-logical-plan/src/ops/agg.rs | 42 +++++ src/daft-logical-plan/src/ops/concat.rs | 38 ++++- src/daft-logical-plan/src/ops/distinct.rs | 35 +++- src/daft-logical-plan/src/ops/explode.rs | 18 ++ src/daft-logical-plan/src/ops/filter.rs | 32 +++- src/daft-logical-plan/src/ops/join.rs | 52 ++++++ src/daft-logical-plan/src/ops/limit.rs | 46 ++++- .../src/ops/monotonically_increasing_id.rs | 19 ++- src/daft-logical-plan/src/ops/pivot.rs | 13 ++ src/daft-logical-plan/src/ops/project.rs | 19 ++- src/daft-logical-plan/src/ops/repartition.rs | 13 ++ src/daft-logical-plan/src/ops/sample.rs | 20 ++- .../src/ops/set_operations.rs | 5 +- src/daft-logical-plan/src/ops/sink.rs | 18 +- src/daft-logical-plan/src/ops/sort.rs | 14 +- src/daft-logical-plan/src/ops/source.rs | 87 +++++++++- src/daft-logical-plan/src/ops/unpivot.rs | 42 +++++ .../src/optimization/optimizer.rs | 16 +- .../rules/eliminate_cross_join.rs | 58 ++++--- .../optimization/rules/enrich_with_stats.rs | 27 +++ .../optimization/rules/materialize_scans.rs | 47 +++++ .../src/optimization/rules/mod.rs | 4 + .../optimization/rules/push_down_filter.rs | 4 +- .../src/optimization/rules/push_down_limit.rs | 9 +- src/daft-logical-plan/src/source_info/mod.rs | 4 +- 
src/daft-logical-plan/src/stats.rs | 146 ++++++++++++++++ src/daft-parquet/src/metadata.rs | 4 +- src/daft-physical-plan/src/ops/scan.rs | 8 +- .../src/physical_planner/planner.rs | 9 +- .../src/physical_planner/translate.rs | 32 +++- src/daft-physical-plan/src/plan.rs | 73 +++----- src/daft-scan/Cargo.toml | 1 + src/daft-scan/src/anonymous.rs | 43 ++--- src/daft-scan/src/glob.rs | 160 ++++++++---------- src/daft-scan/src/lib.rs | 73 +++++++- src/daft-scan/src/python.rs | 58 ++++--- src/daft-scan/src/scan_task_iters.rs | 43 ++++- src/daft-sql/src/lib.rs | 44 ++--- src/daft-stats/src/column_stats/mod.rs | 23 ++- src/daft-stats/src/partition_spec.rs | 17 +- src/daft-stats/src/table_metadata.rs | 2 +- src/daft-stats/src/table_stats.rs | 10 ++ src/daft-table/src/lib.rs | 12 ++ 61 files changed, 1556 insertions(+), 416 deletions(-) create mode 100644 src/daft-logical-plan/src/optimization/rules/enrich_with_stats.rs create mode 100644 src/daft-logical-plan/src/optimization/rules/materialize_scans.rs create mode 100644 src/daft-logical-plan/src/stats.rs diff --git a/Cargo.lock b/Cargo.lock index fd22dcfa10..0771b3414b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1843,6 +1843,16 @@ dependencies = [ "memchr", ] +[[package]] +name = "ctor" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a2785755761f3ddc1492979ce1e48d2c00d09311c39e4466429188f3dd6501" +dependencies = [ + "quote", + "syn 2.0.87", +] + [[package]] name = "daft" version = "0.3.0-dev0" @@ -2398,6 +2408,7 @@ dependencies = [ "common-py-serde", "common-runtime", "common-scan-info", + "ctor", "daft-core", "daft-csv", "daft-decoding", diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index f971a078ad..77528da220 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -184,9 +184,10 @@ def explain( return None def num_partitions(self) -> int: - daft_execution_config = get_context().daft_execution_config # We need to run the optimizer since that could change the number of partitions - return self.__builder.optimize().to_physical_plan_scheduler(daft_execution_config).num_partitions() + return ( + self.__builder.optimize().to_physical_plan_scheduler(get_context().daft_execution_config).num_partitions() + ) @DataframePublicAPI def schema(self) -> Schema: diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index bbd53b6024..ac8600936d 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -37,7 +37,7 @@ impl DaftPlanningConfig { /// 3. Task generation from physical plan /// 4. Task scheduling /// 5. 
Task local execution -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct DaftExecutionConfig { pub scan_tasks_min_size_bytes: usize, pub scan_tasks_max_size_bytes: usize, diff --git a/src/common/scan-info/src/lib.rs b/src/common/scan-info/src/lib.rs index ba3a201614..44727c61df 100644 --- a/src/common/scan-info/src/lib.rs +++ b/src/common/scan-info/src/lib.rs @@ -10,7 +10,7 @@ mod scan_operator; mod scan_task; pub mod test; -use std::{fmt::Debug, hash::Hash}; +use std::{fmt::Debug, hash::Hash, sync::Arc}; use daft_schema::schema::SchemaRef; pub use expr_rewriter::{rewrite_predicate_for_partitioning, PredicateGroups}; @@ -19,11 +19,35 @@ pub use pushdowns::Pushdowns; #[cfg(feature = "python")] pub use python::register_modules; pub use scan_operator::{ScanOperator, ScanOperatorRef}; -pub use scan_task::{BoxScanTaskLikeIter, ScanTaskLike, ScanTaskLikeRef}; +pub use scan_task::{ScanTaskLike, ScanTaskLikeRef, SPLIT_AND_MERGE_PASS}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ScanState { + Operator(ScanOperatorRef), + Tasks(Arc>), +} + +impl ScanState { + pub fn multiline_display(&self) -> Vec { + match self { + Self::Operator(scan_op) => scan_op.0.multiline_display(), + Self::Tasks(scan_tasks) => { + vec![format!("Num Scan Tasks = {}", scan_tasks.len())] + } + } + } + + pub fn get_scan_op(&self) -> &ScanOperatorRef { + match self { + Self::Operator(scan_op) => scan_op, + Self::Tasks(_) => panic!("Tried to get scan op from materialized physical scan info"), + } + } +} #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct PhysicalScanInfo { - pub scan_op: ScanOperatorRef, + pub scan_state: ScanState, pub source_schema: SchemaRef, pub partitioning_keys: Vec, pub pushdowns: Pushdowns, @@ -38,7 +62,7 @@ impl PhysicalScanInfo { pushdowns: Pushdowns, ) -> Self { Self { - scan_op, + scan_state: ScanState::Operator(scan_op), source_schema, partitioning_keys, pushdowns, @@ -48,7 +72,7 @@ impl PhysicalScanInfo { #[must_use] pub fn with_pushdowns(&self, pushdowns: Pushdowns) -> Self { Self { - scan_op: self.scan_op.clone(), + scan_state: self.scan_state.clone(), source_schema: self.source_schema.clone(), partitioning_keys: self.partitioning_keys.clone(), pushdowns, diff --git a/src/common/scan-info/src/scan_operator.rs b/src/common/scan-info/src/scan_operator.rs index b965f62bb6..10c798fe5f 100644 --- a/src/common/scan-info/src/scan_operator.rs +++ b/src/common/scan-info/src/scan_operator.rs @@ -4,7 +4,6 @@ use std::{ sync::Arc, }; -use common_daft_config::DaftExecutionConfig; use common_error::DaftResult; use daft_schema::schema::SchemaRef; @@ -33,11 +32,7 @@ pub trait ScanOperator: Send + Sync + Debug { /// If cfg provided, `to_scan_tasks` should apply the appropriate transformations /// (merging, splitting) to the outputted scan tasks - fn to_scan_tasks( - &self, - pushdowns: Pushdowns, - config: Option<&DaftExecutionConfig>, - ) -> DaftResult>; + fn to_scan_tasks(&self, pushdowns: Pushdowns) -> DaftResult>; } impl Display for dyn ScanOperator { diff --git a/src/common/scan-info/src/scan_task.rs b/src/common/scan-info/src/scan_task.rs index 3d6e5f466b..886fe42891 100644 --- a/src/common/scan-info/src/scan_task.rs +++ b/src/common/scan-info/src/scan_task.rs @@ -1,4 +1,9 @@ -use std::{any::Any, fmt::Debug, sync::Arc}; +use std::{ + any::Any, + fmt::Debug, + hash::{Hash, Hasher}, + sync::{Arc, OnceLock}, +}; use common_daft_config::DaftExecutionConfig; use common_display::DisplayAs; @@ -13,6 +18,7 @@ pub trait ScanTaskLike: Debug + DisplayAs 
+ Send + Sync { fn as_any(&self) -> &dyn Any; fn as_any_arc(self: Arc) -> Arc; fn dyn_eq(&self, other: &dyn ScanTaskLike) -> bool; + fn dyn_hash(&self, state: &mut dyn Hasher); #[must_use] fn materialized_schema(&self) -> SchemaRef; #[must_use] @@ -35,10 +41,27 @@ pub trait ScanTaskLike: Debug + DisplayAs + Send + Sync { pub type ScanTaskLikeRef = Arc; +impl Eq for dyn ScanTaskLike + '_ {} + impl PartialEq for dyn ScanTaskLike + '_ { fn eq(&self, other: &Self) -> bool { self.dyn_eq(other) } } -pub type BoxScanTaskLikeIter = Box>>>; +impl Hash for dyn ScanTaskLike + '_ { + fn hash(&self, state: &mut H) { + self.dyn_hash(state); + } +} + +// Forward declare splitting and merging pass so that scan tasks can be split and merged +// with common/scan-info without importing daft-scan. +pub type SplitAndMergePass = dyn Fn( + Arc>, + &Pushdowns, + &DaftExecutionConfig, + ) -> DaftResult>> + + Sync + + Send; +pub static SPLIT_AND_MERGE_PASS: OnceLock<&SplitAndMergePass> = OnceLock::new(); diff --git a/src/common/scan-info/src/test/mod.rs b/src/common/scan-info/src/test/mod.rs index c5b14039c1..2fd717db1b 100644 --- a/src/common/scan-info/src/test/mod.rs +++ b/src/common/scan-info/src/test/mod.rs @@ -1,4 +1,8 @@ -use std::{any::Any, sync::Arc}; +use std::{ + any::Any, + hash::{Hash, Hasher}, + sync::Arc, +}; use common_daft_config::DaftExecutionConfig; use common_display::DisplayAs; @@ -9,7 +13,7 @@ use serde::{Deserialize, Serialize}; use crate::{PartitionField, Pushdowns, ScanOperator, ScanTaskLike, ScanTaskLikeRef}; -#[derive(Debug, Serialize, Deserialize, PartialEq)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Hash)] struct DummyScanTask { pub schema: SchemaRef, pub pushdowns: Pushdowns, @@ -38,6 +42,10 @@ impl ScanTaskLike for DummyScanTask { .map_or(false, |a| a == self) } + fn dyn_hash(&self, mut state: &mut dyn Hasher) { + self.hash(&mut state); + } + fn materialized_schema(&self) -> SchemaRef { self.schema.clone() } @@ -121,11 +129,7 @@ impl ScanOperator for DummyScanOperator { vec!["DummyScanOperator".to_string()] } - fn to_scan_tasks( - &self, - pushdowns: Pushdowns, - _: Option<&DaftExecutionConfig>, - ) -> DaftResult> { + fn to_scan_tasks(&self, pushdowns: Pushdowns) -> DaftResult> { let scan_task = Arc::new(DummyScanTask { schema: self.schema.clone(), pushdowns, diff --git a/src/daft-catalog/src/lib.rs b/src/daft-catalog/src/lib.rs index 8492e6ae37..73f75864c8 100644 --- a/src/daft-catalog/src/lib.rs +++ b/src/daft-catalog/src/lib.rs @@ -168,21 +168,21 @@ mod tests { ]) .unwrap(), ); - LogicalPlan::Source(Source { - output_schema: schema.clone(), - source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { + LogicalPlan::Source(Source::new( + schema.clone(), + Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { source_schema: schema, clustering_spec: Arc::new(ClusteringSpec::unknown()), source_id: 0, })), - }) + )) .arced() } #[test] fn test_register_and_unregister_named_table() { let mut catalog = DaftMetaCatalog::new_from_env(); - let plan = LogicalPlanBuilder::new(mock_plan(), None); + let plan = LogicalPlanBuilder::from(mock_plan()); // Register a table assert!(catalog @@ -198,7 +198,7 @@ mod tests { #[test] fn test_read_registered_table() { let mut catalog = DaftMetaCatalog::new_from_env(); - let plan = LogicalPlanBuilder::new(mock_plan(), None); + let plan = LogicalPlanBuilder::from(mock_plan()); catalog.register_named_table("test_table", plan).unwrap(); diff --git a/src/daft-connect/src/op/execute/root.rs b/src/daft-connect/src/op/execute/root.rs index 
1e1fac147b..4f765243c8 100644 --- a/src/daft-connect/src/op/execute/root.rs +++ b/src/daft-connect/src/op/execute/root.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, future::ready}; +use std::{collections::HashMap, future::ready, sync::Arc}; use common_daft_config::DaftExecutionConfig; use daft_local_execution::NativeExecutor; @@ -33,10 +33,10 @@ impl Session { let execution_fut = async { let plan = translation::to_logical_plan(command)?; let optimized_plan = plan.optimize()?; - let cfg = DaftExecutionConfig::default(); + let cfg = Arc::new(DaftExecutionConfig::default()); let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?; let mut result_stream = native_executor - .run(HashMap::new(), cfg.into(), None)? + .run(HashMap::new(), cfg, None)? .into_stream(); while let Some(result) = result_stream.next().await { diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 6f804c150f..c931614ff3 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -15,7 +15,7 @@ use daft_local_plan::{ Limit, LocalPhysicalPlan, MonotonicallyIncreasingId, PhysicalWrite, Pivot, Project, Sample, Sort, UnGroupedAggregate, Unpivot, }; -use daft_logical_plan::JoinType; +use daft_logical_plan::{stats::StatsState, JoinType}; use daft_micropartition::MicroPartition; use daft_physical_plan::{extract_agg_expr, populate_aggregation_stages}; use daft_scan::ScanTaskRef; @@ -319,18 +319,54 @@ pub fn physical_plan_to_pipeline( null_equals_null, join_type, schema, + .. }) => { let left_schema = left.schema(); let right_schema = right.schema(); - // Determine the build and probe sides based on the join type - // Currently it is a naive determination, in the future we should leverage the cardinality of the tables - // to determine the build and probe sides + // To determine whether to use the left or right side of a join for building a probe table, we consider: + // 1. Cardinality of the sides. Probe tables should be built on the smaller side. + // 2. Join type. Different join types have different requirements for which side can build the probe table. + let left_stats_state = left.get_stats_state(); + let right_stats_state = right.get_stats_state(); + let build_on_left = match (left_stats_state, right_stats_state) { + (StatsState::Materialized(left_stats), StatsState::Materialized(right_stats)) => { + left_stats.approx_stats.upper_bound_bytes + <= right_stats.approx_stats.upper_bound_bytes + } + // If stats are only available on the right side of the join, and the upper bound bytes on the + // right are under the broadcast join size threshold, we build on the right instead of the left. + (StatsState::NotMaterialized, StatsState::Materialized(right_stats)) => right_stats + .approx_stats + .upper_bound_bytes + .map_or(true, |size| size > cfg.broadcast_join_size_bytes_threshold), + // If stats are not available, we fall back and build on the left by default. + _ => true, + }; + + // TODO(desmond): We might potentially want to flip the probe table side for + // left/right outer joins if one side is significantly larger. Needs to be tuned. + // + // In greater detail, consider a right outer join where the left side is several orders + // of magnitude larger than the right. An extreme example might have 1B rows on the left, + // and 10 rows on the right. + // + // Typically we would build the probe table on the left, then stream rows from the right + // to match against the probe table. 
But in this case we would have a giant intermediate + // probe table. + // + // An alternative 2-pass algorithm would be to: + // 1. Build the probe table on the right, but add a second data structure to keep track of + // which rows on the right have been matched. + // 2. Stream rows on the left until all rows have been seen. + // 3. Finally, emit all unmatched rows from the right. let build_on_left = match join_type { - JoinType::Inner => true, - JoinType::Right => true, - JoinType::Outer => true, + JoinType::Inner => build_on_left, + JoinType::Outer => build_on_left, + // For left outer joins, we build on right so we can stream the left side. JoinType::Left => false, + // For right outer joins, we build on left so we can stream the right side. + JoinType::Right => true, JoinType::Anti | JoinType::Semi => false, }; let (build_on, probe_on, build_child, probe_child) = match build_on_left { @@ -421,6 +457,7 @@ pub fn physical_plan_to_pipeline( left_schema, right_schema, *join_type, + build_on_left, common_join_keys, schema, probe_state_bridge, diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index a8ca50130f..4af93075a0 100644 --- a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -126,6 +126,7 @@ struct OuterHashJoinParams { right_non_join_columns: Vec, right_non_join_schema: SchemaRef, join_type: JoinType, + build_on_left: bool, } pub(crate) struct OuterHashJoinProbeSink { @@ -134,16 +135,23 @@ pub(crate) struct OuterHashJoinProbeSink { probe_state_bridge: ProbeStateBridgeRef, } +#[allow(clippy::too_many_arguments)] impl OuterHashJoinProbeSink { pub(crate) fn new( probe_on: Vec, left_schema: &SchemaRef, right_schema: &SchemaRef, join_type: JoinType, + build_on_left: bool, common_join_keys: IndexSet, output_schema: &SchemaRef, probe_state_bridge: ProbeStateBridgeRef, ) -> Self { + // For outer joins, we need to swap the left and right schemas if we are building on the right. + let (left_schema, right_schema) = match (join_type, build_on_left) { + (JoinType::Outer, false) => (right_schema, left_schema), + _ => (left_schema, right_schema), + }; let left_non_join_columns = left_schema .fields .keys() @@ -168,6 +176,7 @@ impl OuterHashJoinProbeSink { right_non_join_columns, right_non_join_schema, join_type, + build_on_left, }), output_schema: output_schema.clone(), probe_state_bridge, @@ -243,6 +252,7 @@ impl OuterHashJoinProbeSink { ))) } + #[allow(clippy::too_many_arguments)] fn probe_outer( input: &Arc, probe_state: &ProbeState, @@ -251,6 +261,7 @@ impl OuterHashJoinProbeSink { common_join_keys: &[String], left_non_join_columns: &[String], right_non_join_columns: &[String], + build_on_left: bool, ) -> DaftResult> { let probe_table = probe_state.get_probeable().clone(); let tables = probe_state.get_tables().clone(); @@ -297,6 +308,12 @@ impl OuterHashJoinProbeSink { let join_table = probe_side_table.get_columns(common_join_keys)?; let left = build_side_table.get_columns(left_non_join_columns)?; let right = probe_side_table.get_columns(right_non_join_columns)?; + // If we built the probe table on the right, flip the order of union. 
+ let (left, right) = if build_on_left { + (left, right) + } else { + (right, left) + }; let final_table = join_table.union(&left)?.union(&right)?; Ok(Arc::new(MicroPartition::new_loaded( final_table.schema.clone(), @@ -310,6 +327,7 @@ impl OuterHashJoinProbeSink { common_join_keys: &[String], left_non_join_columns: &[String], right_non_join_schema: &SchemaRef, + build_on_left: bool, ) -> DaftResult>> { let mut states_iter = states.iter_mut(); let first_state = states_iter @@ -372,6 +390,12 @@ impl OuterHashJoinProbeSink { .collect::>(); Table::new_unchecked(right_non_join_schema.clone(), columns, left.len()) }; + // If we built the probe table on the right, flip the order of union. + let (left, right) = if build_on_left { + (left, right) + } else { + (right, left) + }; let final_table = join_table.union(&left)?.union(&right)?; Ok(Some(Arc::new(MicroPartition::new_loaded( final_table.schema.clone(), @@ -426,6 +450,7 @@ impl StreamingSink for OuterHashJoinProbeSink { ¶ms.common_join_keys, ¶ms.left_non_join_columns, ¶ms.right_non_join_columns, + params.build_on_left, ) } _ => unreachable!( @@ -462,6 +487,7 @@ impl StreamingSink for OuterHashJoinProbeSink { ¶ms.common_join_keys, ¶ms.left_non_join_columns, ¶ms.right_non_join_schema, + params.build_on_left, ) .await }) diff --git a/src/daft-local-plan/src/plan.rs b/src/daft-local-plan/src/plan.rs index 26f796b1e4..7d541421a3 100644 --- a/src/daft-local-plan/src/plan.rs +++ b/src/daft-local-plan/src/plan.rs @@ -4,7 +4,10 @@ use common_resource_request::ResourceRequest; use common_scan_info::{Pushdowns, ScanTaskLikeRef}; use daft_core::prelude::*; use daft_dsl::{AggExpr, ExprRef}; -use daft_logical_plan::{InMemoryInfo, OutputFileInfo}; +use daft_logical_plan::{ + stats::{PlanStats, StatsState}, + InMemoryInfo, OutputFileInfo, +}; pub type LocalPhysicalPlanRef = Arc; #[derive(Debug, strum::IntoStaticStr)] @@ -56,24 +59,54 @@ impl LocalPhysicalPlan { self.into() } - pub(crate) fn in_memory_scan(in_memory_info: InMemoryInfo) -> LocalPhysicalPlanRef { + pub fn get_stats_state(&self) -> &StatsState { + match self { + Self::InMemoryScan(InMemoryScan { stats_state, .. }) + | Self::PhysicalScan(PhysicalScan { stats_state, .. }) + | Self::EmptyScan(EmptyScan { stats_state, .. }) + | Self::Project(Project { stats_state, .. }) + | Self::ActorPoolProject(ActorPoolProject { stats_state, .. }) + | Self::Filter(Filter { stats_state, .. }) + | Self::Limit(Limit { stats_state, .. }) + | Self::Explode(Explode { stats_state, .. }) + | Self::Unpivot(Unpivot { stats_state, .. }) + | Self::Sort(Sort { stats_state, .. }) + | Self::Sample(Sample { stats_state, .. }) + | Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId { stats_state, .. }) + | Self::UnGroupedAggregate(UnGroupedAggregate { stats_state, .. }) + | Self::HashAggregate(HashAggregate { stats_state, .. }) + | Self::Pivot(Pivot { stats_state, .. }) + | Self::Concat(Concat { stats_state, .. }) + | Self::HashJoin(HashJoin { stats_state, .. }) + | Self::PhysicalWrite(PhysicalWrite { stats_state, .. }) => stats_state, + #[cfg(feature = "python")] + Self::CatalogWrite(CatalogWrite { stats_state, .. }) + | Self::LanceWrite(LanceWrite { stats_state, .. 
}) => stats_state, + } + } + + pub(crate) fn in_memory_scan( + in_memory_info: InMemoryInfo, + stats_state: StatsState, + ) -> LocalPhysicalPlanRef { Self::InMemoryScan(InMemoryScan { info: in_memory_info, - plan_stats: PlanStats {}, + stats_state, }) .arced() } pub(crate) fn physical_scan( - scan_tasks: Vec, + scan_tasks: Arc>, pushdowns: Pushdowns, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::PhysicalScan(PhysicalScan { scan_tasks, pushdowns, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -81,29 +114,37 @@ impl LocalPhysicalPlan { pub(crate) fn empty_scan(schema: SchemaRef) -> LocalPhysicalPlanRef { Self::EmptyScan(EmptyScan { schema, - plan_stats: PlanStats {}, + stats_state: StatsState::Materialized(PlanStats::empty().into()), }) .arced() } - pub(crate) fn filter(input: LocalPhysicalPlanRef, predicate: ExprRef) -> LocalPhysicalPlanRef { + pub(crate) fn filter( + input: LocalPhysicalPlanRef, + predicate: ExprRef, + stats_state: StatsState, + ) -> LocalPhysicalPlanRef { let schema = input.schema().clone(); Self::Filter(Filter { input, predicate, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } - pub(crate) fn limit(input: LocalPhysicalPlanRef, num_rows: i64) -> LocalPhysicalPlanRef { + pub(crate) fn limit( + input: LocalPhysicalPlanRef, + num_rows: i64, + stats_state: StatsState, + ) -> LocalPhysicalPlanRef { let schema = input.schema().clone(); Self::Limit(Limit { input, num_rows, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -112,12 +153,13 @@ impl LocalPhysicalPlan { input: LocalPhysicalPlanRef, to_explode: Vec, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::Explode(Explode { input, to_explode, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -126,12 +168,13 @@ impl LocalPhysicalPlan { input: LocalPhysicalPlanRef, projection: Vec, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::Project(Project { input, projection, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -140,12 +183,13 @@ impl LocalPhysicalPlan { input: LocalPhysicalPlanRef, projection: Vec, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::ActorPoolProject(ActorPoolProject { input, projection, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -154,12 +198,13 @@ impl LocalPhysicalPlan { input: LocalPhysicalPlanRef, aggregations: Vec, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::UnGroupedAggregate(UnGroupedAggregate { input, aggregations, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -169,13 +214,14 @@ impl LocalPhysicalPlan { aggregations: Vec, group_by: Vec, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::HashAggregate(HashAggregate { input, aggregations, group_by, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -187,6 +233,7 @@ impl LocalPhysicalPlan { variable_name: String, value_name: String, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::Unpivot(Unpivot { input, @@ -195,11 +242,12 @@ impl LocalPhysicalPlan { variable_name, value_name, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } + #[allow(clippy::too_many_arguments)] pub(crate) fn pivot( input: LocalPhysicalPlanRef, group_by: Vec, @@ -208,6 +256,7 @@ impl LocalPhysicalPlan { aggregation: AggExpr, names: Vec, schema: SchemaRef, + stats_state: StatsState, ) -> 
LocalPhysicalPlanRef { Self::Pivot(Pivot { input, @@ -217,7 +266,7 @@ impl LocalPhysicalPlan { aggregation, names, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -227,15 +276,16 @@ impl LocalPhysicalPlan { sort_by: Vec, descending: Vec, nulls_first: Vec, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { let schema = input.schema().clone(); Self::Sort(Sort { input, sort_by, - nulls_first, descending, + nulls_first, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -245,6 +295,7 @@ impl LocalPhysicalPlan { fraction: f64, with_replacement: bool, seed: Option, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { let schema = input.schema().clone(); Self::Sample(Sample { @@ -253,7 +304,7 @@ impl LocalPhysicalPlan { with_replacement, seed, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -262,16 +313,18 @@ impl LocalPhysicalPlan { input: LocalPhysicalPlanRef, column_name: String, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId { input, column_name, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } + #[allow(clippy::too_many_arguments)] pub(crate) fn hash_join( left: LocalPhysicalPlanRef, right: LocalPhysicalPlanRef, @@ -280,6 +333,7 @@ impl LocalPhysicalPlan { null_equals_null: Option>, join_type: JoinType, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::HashJoin(HashJoin { left, @@ -289,6 +343,7 @@ impl LocalPhysicalPlan { null_equals_null, join_type, schema, + stats_state, }) .arced() } @@ -296,13 +351,14 @@ impl LocalPhysicalPlan { pub(crate) fn concat( input: LocalPhysicalPlanRef, other: LocalPhysicalPlanRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { let schema = input.schema().clone(); Self::Concat(Concat { input, other, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -312,13 +368,14 @@ impl LocalPhysicalPlan { data_schema: SchemaRef, file_schema: SchemaRef, file_info: OutputFileInfo, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::PhysicalWrite(PhysicalWrite { input, data_schema, file_schema, file_info, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -329,13 +386,14 @@ impl LocalPhysicalPlan { catalog_type: daft_logical_plan::CatalogType, data_schema: SchemaRef, file_schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::CatalogWrite(CatalogWrite { input, catalog_type, data_schema, file_schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -346,13 +404,14 @@ impl LocalPhysicalPlan { lance_info: daft_logical_plan::LanceCatalogInfo, data_schema: SchemaRef, file_schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::LanceWrite(LanceWrite { input, lance_info, data_schema, file_schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -388,21 +447,21 @@ impl LocalPhysicalPlan { #[derive(Debug)] pub struct InMemoryScan { pub info: InMemoryInfo, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] pub struct PhysicalScan { - pub scan_tasks: Vec, + pub scan_tasks: Arc>, pub pushdowns: Pushdowns, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] pub struct EmptyScan { pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -410,7 +469,7 @@ pub struct Project { pub input: LocalPhysicalPlanRef, pub projection: Vec, pub schema: SchemaRef, - pub 
plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -418,7 +477,7 @@ pub struct ActorPoolProject { pub input: LocalPhysicalPlanRef, pub projection: Vec, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -426,7 +485,7 @@ pub struct Filter { pub input: LocalPhysicalPlanRef, pub predicate: ExprRef, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -434,7 +493,7 @@ pub struct Limit { pub input: LocalPhysicalPlanRef, pub num_rows: i64, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -442,7 +501,7 @@ pub struct Explode { pub input: LocalPhysicalPlanRef, pub to_explode: Vec, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -452,7 +511,7 @@ pub struct Sort { pub descending: Vec, pub nulls_first: Vec, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -462,7 +521,7 @@ pub struct Sample { pub with_replacement: bool, pub seed: Option, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -470,7 +529,7 @@ pub struct MonotonicallyIncreasingId { pub input: LocalPhysicalPlanRef, pub column_name: String, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -478,7 +537,7 @@ pub struct UnGroupedAggregate { pub input: LocalPhysicalPlanRef, pub aggregations: Vec, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -487,7 +546,7 @@ pub struct HashAggregate { pub aggregations: Vec, pub group_by: Vec, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -498,7 +557,7 @@ pub struct Unpivot { pub variable_name: String, pub value_name: String, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -510,7 +569,7 @@ pub struct Pivot { pub aggregation: AggExpr, pub names: Vec, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -522,6 +581,7 @@ pub struct HashJoin { pub null_equals_null: Option>, pub join_type: JoinType, pub schema: SchemaRef, + pub stats_state: StatsState, } #[derive(Debug)] @@ -529,7 +589,7 @@ pub struct Concat { pub input: LocalPhysicalPlanRef, pub other: LocalPhysicalPlanRef, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -538,7 +598,7 @@ pub struct PhysicalWrite { pub data_schema: SchemaRef, pub file_schema: SchemaRef, pub file_info: OutputFileInfo, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[cfg(feature = "python")] @@ -548,7 +608,7 @@ pub struct CatalogWrite { pub catalog_type: daft_logical_plan::CatalogType, pub data_schema: SchemaRef, pub file_schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[cfg(feature = "python")] @@ -558,8 +618,5 @@ pub struct LanceWrite { pub lance_info: daft_logical_plan::LanceCatalogInfo, pub data_schema: SchemaRef, pub file_schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } - -#[derive(Debug)] -pub struct PlanStats {} diff --git a/src/daft-local-plan/src/translate.rs b/src/daft-local-plan/src/translate.rs index b8f214dd4f..aac5046e46 100644 --- a/src/daft-local-plan/src/translate.rs +++ b/src/daft-local-plan/src/translate.rs @@ 
-1,4 +1,7 @@ +use std::sync::Arc; + use common_error::{DaftError, DaftResult}; +use common_scan_info::ScanState; use daft_core::join::JoinStrategy; use daft_dsl::ExprRef; use daft_logical_plan::{JoinType, LogicalPlan, LogicalPlanRef, SourceInfo}; @@ -9,10 +12,18 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { match plan.as_ref() { LogicalPlan::Source(source) => { match source.source_info.as_ref() { - SourceInfo::InMemory(info) => Ok(LocalPhysicalPlan::in_memory_scan(info.clone())), + SourceInfo::InMemory(info) => Ok(LocalPhysicalPlan::in_memory_scan( + info.clone(), + source.stats_state.clone(), + )), SourceInfo::Physical(info) => { // We should be able to pass the ScanOperator into the physical plan directly but we need to figure out the serialization story - let scan_tasks = info.scan_op.0.to_scan_tasks(info.pushdowns.clone(), None)?; + let scan_tasks = match &info.scan_state { + ScanState::Operator(scan_op) => { + Arc::new(scan_op.0.to_scan_tasks(info.pushdowns.clone())?) + } + ScanState::Tasks(scan_tasks) => scan_tasks.clone(), + }; if scan_tasks.is_empty() { Ok(LocalPhysicalPlan::empty_scan(source.output_schema.clone())) } else { @@ -20,6 +31,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { scan_tasks, info.pushdowns.clone(), source.output_schema.clone(), + source.stats_state.clone(), )) } } @@ -30,11 +42,19 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { } LogicalPlan::Filter(filter) => { let input = translate(&filter.input)?; - Ok(LocalPhysicalPlan::filter(input, filter.predicate.clone())) + Ok(LocalPhysicalPlan::filter( + input, + filter.predicate.clone(), + filter.stats_state.clone(), + )) } LogicalPlan::Limit(limit) => { let input = translate(&limit.input)?; - Ok(LocalPhysicalPlan::limit(input, limit.limit)) + Ok(LocalPhysicalPlan::limit( + input, + limit.limit, + limit.stats_state.clone(), + )) } LogicalPlan::Project(project) => { let input = translate(&project.input)?; @@ -42,6 +62,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { input, project.projection.clone(), project.projected_schema.clone(), + project.stats_state.clone(), )) } LogicalPlan::ActorPoolProject(actor_pool_project) => { @@ -50,6 +71,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { input, actor_pool_project.projection.clone(), actor_pool_project.projected_schema.clone(), + actor_pool_project.stats_state.clone(), )) } LogicalPlan::Sample(sample) => { @@ -59,6 +81,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { sample.fraction, sample.with_replacement, sample.seed, + sample.stats_state.clone(), )) } LogicalPlan::Aggregate(aggregate) => { @@ -68,6 +91,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { input, aggregate.aggregations.clone(), aggregate.output_schema.clone(), + aggregate.stats_state.clone(), )) } else { Ok(LocalPhysicalPlan::hash_aggregate( @@ -75,6 +99,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { aggregate.aggregations.clone(), aggregate.groupby.clone(), aggregate.output_schema.clone(), + aggregate.stats_state.clone(), )) } } @@ -87,6 +112,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { unpivot.variable_name.clone(), unpivot.value_name.clone(), unpivot.output_schema.clone(), + unpivot.stats_state.clone(), )) } LogicalPlan::Pivot(pivot) => { @@ -99,6 +125,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { pivot.aggregation.clone(), pivot.names.clone(), pivot.output_schema.clone(), + pivot.stats_state.clone(), )) } LogicalPlan::Sort(sort) => { @@ -108,6 +135,7 @@ pub fn translate(plan: 
&LogicalPlanRef) -> DaftResult { sort.sort_by.clone(), sort.descending.clone(), sort.nulls_first.clone(), + sort.stats_state.clone(), )) } LogicalPlan::Join(join) => { @@ -134,6 +162,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { join.null_equals_nulls.clone(), join.join_type, join.output_schema.clone(), + join.stats_state.clone(), )) } LogicalPlan::Distinct(distinct) => { @@ -150,12 +179,17 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { vec![], col_exprs, schema, + distinct.stats_state.clone(), )) } LogicalPlan::Concat(concat) => { let input = translate(&concat.input)?; let other = translate(&concat.other)?; - Ok(LocalPhysicalPlan::concat(input, other)) + Ok(LocalPhysicalPlan::concat( + input, + other, + concat.stats_state.clone(), + )) } LogicalPlan::Repartition(repartition) => { log::warn!("Repartition Not supported for Local Executor!; This will be a No-Op"); @@ -167,6 +201,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { input, monotonically_increasing_id.column_name.clone(), monotonically_increasing_id.schema.clone(), + monotonically_increasing_id.stats_state.clone(), )) } LogicalPlan::Sink(sink) => { @@ -179,6 +214,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { data_schema, sink.schema.clone(), info.clone(), + sink.stats_state.clone(), )), #[cfg(feature = "python")] SinkInfo::CatalogInfo(info) => match &info.catalog { @@ -189,6 +225,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { info.catalog.clone(), data_schema, sink.schema.clone(), + sink.stats_state.clone(), )) } daft_logical_plan::CatalogType::Lance(info) => { @@ -197,6 +234,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { info.clone(), data_schema, sink.schema.clone(), + sink.stats_state.clone(), )) } }, @@ -208,6 +246,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { input, explode.to_explode.clone(), explode.exploded_schema.clone(), + explode.stats_state.clone(), )) } _ => todo!("{} not yet implemented", plan.name()), diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder.rs index c945c80203..f40a55ed4f 100644 --- a/src/daft-logical-plan/src/builder.rs +++ b/src/daft-logical-plan/src/builder.rs @@ -134,7 +134,7 @@ impl LogicalPlanBuilder { )); let logical_plan: LogicalPlan = ops::Source::new(schema, source_info.into()).into(); - Ok(Self::new(logical_plan.into(), None)) + Ok(Self::from(Arc::new(logical_plan))) } pub fn table_scan( @@ -186,7 +186,7 @@ impl LogicalPlanBuilder { schema_with_generated_fields }; let logical_plan: LogicalPlan = ops::Source::new(output_schema, source_info.into()).into(); - Ok(Self::new(logical_plan.into(), None)) + Ok(Self::from(Arc::new(logical_plan))) } pub fn select(&self, to_select: Vec) -> DaftResult { diff --git a/src/daft-logical-plan/src/display.rs b/src/daft-logical-plan/src/display.rs index be83f5237b..84db90d273 100644 --- a/src/daft-logical-plan/src/display.rs +++ b/src/daft-logical-plan/src/display.rs @@ -51,14 +51,14 @@ mod test { ]) .unwrap(), ); - LogicalPlan::Source(Source { - output_schema: schema.clone(), - source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { + LogicalPlan::Source(Source::new( + schema.clone(), + Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { source_schema: schema, clustering_spec: Arc::new(ClusteringSpec::unknown()), source_id: 0, })), - }) + )) .arced() } @@ -71,25 +71,25 @@ mod test { ]) .unwrap(), ); - LogicalPlan::Source(Source { - output_schema: schema.clone(), - source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { 
+ LogicalPlan::Source(Source::new( + schema.clone(), + Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { source_schema: schema, clustering_spec: Arc::new(ClusteringSpec::unknown()), source_id: 0, })), - }) + )) .arced() } #[test] // create a random, complex plan and check if it can be displayed as expected fn test_mermaid_display() -> DaftResult<()> { - let subplan = LogicalPlanBuilder::new(plan_1(), None) + let subplan = LogicalPlanBuilder::from(plan_1()) .filter(col("id").eq(lit(1)))? .build(); - let subplan2 = LogicalPlanBuilder::new(plan_2(), None) + let subplan2 = LogicalPlanBuilder::from(plan_2()) .filter( startswith(col("last_name"), lit("S")).and(endswith(col("last_name"), lit("n"))), )? @@ -99,7 +99,7 @@ mod test { .sort(vec![col("last_name")], vec![false], vec![false])? .build(); - let plan = LogicalPlanBuilder::new(subplan, None) + let plan = LogicalPlanBuilder::from(subplan) .join( subplan2, vec![col("id")], @@ -159,11 +159,11 @@ Project1 --> Limit0 #[test] // create a random, complex plan and check if it can be displayed as expected fn test_mermaid_display_simple() -> DaftResult<()> { - let subplan = LogicalPlanBuilder::new(plan_1(), None) + let subplan = LogicalPlanBuilder::from(plan_1()) .filter(col("id").eq(lit(1)))? .build(); - let subplan2 = LogicalPlanBuilder::new(plan_2(), None) + let subplan2 = LogicalPlanBuilder::from(plan_2()) .filter( startswith(col("last_name"), lit("S")).and(endswith(col("last_name"), lit("n"))), )? @@ -173,7 +173,7 @@ Project1 --> Limit0 .sort(vec![col("last_name")], vec![false], vec![false])? .build(); - let plan = LogicalPlanBuilder::new(subplan, None) + let plan = LogicalPlanBuilder::from(subplan) .join_with_null_safe_equal( subplan2, vec![col("id")], diff --git a/src/daft-logical-plan/src/lib.rs b/src/daft-logical-plan/src/lib.rs index 88f21f8797..5296d99a23 100644 --- a/src/daft-logical-plan/src/lib.rs +++ b/src/daft-logical-plan/src/lib.rs @@ -10,6 +10,7 @@ pub mod optimization; pub mod partitioning; pub mod sink_info; pub mod source_info; +pub mod stats; #[cfg(test)] mod test; mod treenode; diff --git a/src/daft-logical-plan/src/logical_plan.rs b/src/daft-logical-plan/src/logical_plan.rs index 01c6b510c8..fc2f065038 100644 --- a/src/daft-logical-plan/src/logical_plan.rs +++ b/src/daft-logical-plan/src/logical_plan.rs @@ -8,6 +8,7 @@ use indexmap::IndexSet; use snafu::Snafu; pub use crate::ops::*; +use crate::stats::PlanStats; /// Logical plan for a Daft query. #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -39,6 +40,7 @@ impl LogicalPlan { pub fn arced(self) -> Arc { Arc::new(self) } + pub fn schema(&self) -> SchemaRef { match self { Self::Source(Source { output_schema, .. }) => output_schema.clone(), @@ -197,29 +199,91 @@ impl LogicalPlan { } } + pub fn materialized_stats(&self) -> &PlanStats { + match self { + Self::Source(Source { stats_state, .. }) + | Self::Project(Project { stats_state, .. }) + | Self::ActorPoolProject(ActorPoolProject { stats_state, .. }) + | Self::Filter(Filter { stats_state, .. }) + | Self::Limit(Limit { stats_state, .. }) + | Self::Explode(Explode { stats_state, .. }) + | Self::Unpivot(Unpivot { stats_state, .. }) + | Self::Sort(Sort { stats_state, .. }) + | Self::Repartition(Repartition { stats_state, .. }) + | Self::Distinct(Distinct { stats_state, .. }) + | Self::Aggregate(Aggregate { stats_state, .. }) + | Self::Pivot(Pivot { stats_state, .. }) + | Self::Concat(Concat { stats_state, .. }) + | Self::Join(Join { stats_state, .. }) + | Self::Sink(Sink { stats_state, .. 
}) + | Self::Sample(Sample { stats_state, .. }) + | Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId { stats_state, .. }) => { + stats_state.materialized_stats() + } + Self::Intersect(_) => { + panic!("Intersect nodes should be optimized away before stats are materialized") + } + Self::Union(_) => { + panic!("Union nodes should be optimized away before stats are materialized") + } + } + } + + // Materializes stats over logical plans. If stats are already materialized, this function recomputes stats, which might be + // useful if stats become stale during query planning. + pub fn with_materialized_stats(self) -> Self { + match self { + Self::Source(plan) => Self::Source(plan.with_materialized_stats()), + Self::Project(plan) => Self::Project(plan.with_materialized_stats()), + Self::ActorPoolProject(plan) => Self::ActorPoolProject(plan.with_materialized_stats()), + Self::Filter(plan) => Self::Filter(plan.with_materialized_stats()), + Self::Limit(plan) => Self::Limit(plan.with_materialized_stats()), + Self::Explode(plan) => Self::Explode(plan.with_materialized_stats()), + Self::Unpivot(plan) => Self::Unpivot(plan.with_materialized_stats()), + Self::Sort(plan) => Self::Sort(plan.with_materialized_stats()), + Self::Repartition(plan) => Self::Repartition(plan.with_materialized_stats()), + Self::Distinct(plan) => Self::Distinct(plan.with_materialized_stats()), + Self::Aggregate(plan) => Self::Aggregate(plan.with_materialized_stats()), + Self::Pivot(plan) => Self::Pivot(plan.with_materialized_stats()), + Self::Concat(plan) => Self::Concat(plan.with_materialized_stats()), + Self::Intersect(_) => { + panic!("Intersect should be optimized away before stats are derived") + } + Self::Union(_) => { + panic!("Union should be optimized away before stats are derived") + } + Self::Join(plan) => Self::Join(plan.with_materialized_stats()), + Self::Sink(plan) => Self::Sink(plan.with_materialized_stats()), + Self::Sample(plan) => Self::Sample(plan.with_materialized_stats()), + Self::MonotonicallyIncreasingId(plan) => { + Self::MonotonicallyIncreasingId(plan.with_materialized_stats()) + } + } + } + pub fn multiline_display(&self) -> Vec { match self { Self::Source(source) => source.multiline_display(), Self::Project(projection) => projection.multiline_display(), Self::ActorPoolProject(projection) => projection.multiline_display(), - Self::Filter(Filter { predicate, .. }) => vec![format!("Filter: {predicate}")], - Self::Limit(Limit { limit, .. 
}) => vec![format!("Limit: {limit}")], + Self::Filter(filter) => filter.multiline_display(), + Self::Limit(limit) => limit.multiline_display(), Self::Explode(explode) => explode.multiline_display(), Self::Unpivot(unpivot) => unpivot.multiline_display(), Self::Sort(sort) => sort.multiline_display(), Self::Repartition(repartition) => repartition.multiline_display(), - Self::Distinct(_) => vec!["Distinct".to_string()], + Self::Distinct(distinct) => distinct.multiline_display(), Self::Aggregate(aggregate) => aggregate.multiline_display(), Self::Pivot(pivot) => pivot.multiline_display(), - Self::Concat(_) => vec!["Concat".to_string()], + Self::Concat(concat) => concat.multiline_display(), Self::Intersect(inner) => inner.multiline_display(), Self::Union(inner) => inner.multiline_display(), Self::Join(join) => join.multiline_display(), Self::Sink(sink) => sink.multiline_display(), - Self::Sample(sample) => { - vec![format!("Sample: {fraction}", fraction = sample.fraction)] + Self::Sample(sample) => sample.multiline_display(), + Self::MonotonicallyIncreasingId(monotonically_increasing_id) => { + monotonically_increasing_id.multiline_display() } - Self::MonotonicallyIncreasingId(_) => vec!["MonotonicallyIncreasingId".to_string()], } } @@ -237,7 +301,7 @@ impl LogicalPlan { Self::Distinct(Distinct { input, .. }) => vec![input], Self::Aggregate(Aggregate { input, .. }) => vec![input], Self::Pivot(Pivot { input, .. }) => vec![input], - Self::Concat(Concat { input, other }) => vec![input, other], + Self::Concat(Concat { input, other, .. }) => vec![input, other], Self::Join(Join { left, right, .. }) => vec![left, right], Self::Sink(Sink { input, .. }) => vec![input], Self::Intersect(Intersect { lhs, rhs, .. }) => vec![lhs, rhs], @@ -267,7 +331,8 @@ impl LogicalPlan { Self::Pivot(Pivot { group_by, pivot_column, value_column, aggregation, names, ..}) => Self::Pivot(Pivot::try_new(input.clone(), group_by.clone(), pivot_column.clone(), value_column.clone(), aggregation.into(), names.clone()).unwrap()), Self::Sink(Sink { sink_info, .. }) => Self::Sink(Sink::try_new(input.clone(), sink_info.clone()).unwrap()), Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId {column_name, .. }) => Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId::new(input.clone(), Some(column_name))), - Self::Unpivot(Unpivot {ids, values, variable_name, value_name, output_schema, ..}) => Self::Unpivot(Unpivot { input: input.clone(), ids: ids.clone(), values: values.clone(), variable_name: variable_name.clone(), value_name: value_name.clone(), output_schema: output_schema.clone() }), + Self::Unpivot(Unpivot {ids, values, variable_name, value_name, output_schema, ..}) => + Self::Unpivot(Unpivot::new(input.clone(), ids.clone(), values.clone(), variable_name.clone(), value_name.clone(), output_schema.clone())), Self::Sample(Sample {fraction, with_replacement, seed, ..}) => Self::Sample(Sample::new(input.clone(), *fraction, *with_replacement, *seed)), Self::Concat(_) => panic!("Concat ops should never have only one input, but got one"), Self::Intersect(_) => panic!("Intersect ops should never have only one input, but got one"), @@ -377,7 +442,7 @@ macro_rules! 
impl_from_data_struct_for_logical_plan { impl From<$name> for Arc { fn from(data: $name) -> Self { - Arc::new(LogicalPlan::$name(data)) + Self::new(LogicalPlan::$name(data)) } } }; diff --git a/src/daft-logical-plan/src/ops/actor_pool_project.rs b/src/daft-logical-plan/src/ops/actor_pool_project.rs index fa1c8bb970..78ec2a681f 100644 --- a/src/daft-logical-plan/src/ops/actor_pool_project.rs +++ b/src/daft-logical-plan/src/ops/actor_pool_project.rs @@ -16,6 +16,7 @@ use snafu::ResultExt; use crate::{ logical_plan::{CreationSnafu, Error, Result}, + stats::StatsState, LogicalPlan, }; @@ -25,6 +26,7 @@ pub struct ActorPoolProject { pub input: Arc, pub projection: Vec, pub projected_schema: SchemaRef, + pub stats_state: StatsState, } impl ActorPoolProject { @@ -64,9 +66,17 @@ impl ActorPoolProject { input, projection, projected_schema, + stats_state: StatsState::NotMaterialized, }) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // TODO(desmond): We can do better estimations with the projection schema. For now, reuse the old logic. + let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone().into()); + self + } + pub fn resource_request(&self) -> Option { get_resource_request(self.projection.as_slice()) } @@ -115,6 +125,9 @@ impl ActorPoolProject { multiline_display.join(", ") )); } + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/agg.rs b/src/daft-logical-plan/src/ops/agg.rs index be82d0d010..5b99338b1c 100644 --- a/src/daft-logical-plan/src/ops/agg.rs +++ b/src/daft-logical-plan/src/ops/agg.rs @@ -7,6 +7,7 @@ use snafu::ResultExt; use crate::{ logical_plan::{self, CreationSnafu}, + stats::{ApproxStats, PlanStats, StatsState}, LogicalPlan, }; @@ -26,6 +27,7 @@ pub struct Aggregate { pub groupby: Vec, pub output_schema: SchemaRef, + pub stats_state: StatsState, } impl Aggregate { @@ -55,9 +57,46 @@ impl Aggregate { aggregations, groupby, output_schema, + stats_state: StatsState::NotMaterialized, }) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // TODO(desmond): We can use the schema here for better estimations. For now, use the old logic. 
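+ // A global aggregation (no groupby columns) produces at most one row, so its upper bounds are clamped to a single row's worth of data; with a groupby, the output can never exceed the input's row and byte bounds.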
+ let input_stats = self.input.materialized_stats(); + let est_bytes_per_row_lower = input_stats.approx_stats.lower_bound_bytes + / (input_stats.approx_stats.lower_bound_rows.max(1)); + let est_bytes_per_row_upper = + input_stats + .approx_stats + .upper_bound_bytes + .and_then(|bytes| { + input_stats + .approx_stats + .upper_bound_rows + .map(|rows| bytes / rows.max(1)) + }); + let approx_stats = if self.groupby.is_empty() { + ApproxStats { + lower_bound_rows: input_stats.approx_stats.lower_bound_rows.min(1), + upper_bound_rows: Some(1), + lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes.min(1) + * est_bytes_per_row_lower, + upper_bound_bytes: est_bytes_per_row_upper, + } + } else { + ApproxStats { + lower_bound_rows: input_stats.approx_stats.lower_bound_rows.min(1), + upper_bound_rows: input_stats.approx_stats.upper_bound_rows, + lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes.min(1) + * est_bytes_per_row_lower, + upper_bound_bytes: input_stats.approx_stats.upper_bound_bytes, + } + }; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!( @@ -74,6 +113,9 @@ impl Aggregate { "Output schema = {}", self.output_schema.short_string() )); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/concat.rs b/src/daft-logical-plan/src/ops/concat.rs index 39541e39de..fb18441c4c 100644 --- a/src/daft-logical-plan/src/ops/concat.rs +++ b/src/daft-logical-plan/src/ops/concat.rs @@ -3,16 +3,29 @@ use std::sync::Arc; use common_error::DaftError; use snafu::ResultExt; -use crate::{logical_plan, logical_plan::CreationSnafu, LogicalPlan}; +use crate::{ + logical_plan::{self, CreationSnafu}, + stats::{PlanStats, StatsState}, + LogicalPlan, +}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Concat { // Upstream nodes. pub input: Arc, pub other: Arc, + pub stats_state: StatsState, } impl Concat { + pub(crate) fn new(input: Arc, other: Arc) -> Self { + Self { + input, + other, + stats_state: StatsState::NotMaterialized, + } + } + pub(crate) fn try_new( input: Arc, other: Arc, @@ -26,6 +39,27 @@ impl Concat { ))) .context(CreationSnafu); } - Ok(Self { input, other }) + Ok(Self { + input, + other, + stats_state: StatsState::NotMaterialized, + }) + } + + pub(crate) fn with_materialized_stats(mut self) -> Self { + // TODO(desmond): We can do better estimations with the projection schema. For now, reuse the old logic. + let input_stats = self.input.materialized_stats(); + let other_stats = self.other.materialized_stats(); + let approx_stats = &input_stats.approx_stats + &other_stats.approx_stats; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![format!("Concat")]; + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } + res } } diff --git a/src/daft-logical-plan/src/ops/distinct.rs b/src/daft-logical-plan/src/ops/distinct.rs index 11fe254e75..899dab940b 100644 --- a/src/daft-logical-plan/src/ops/distinct.rs +++ b/src/daft-logical-plan/src/ops/distinct.rs @@ -1,15 +1,46 @@ use std::sync::Arc; -use crate::LogicalPlan; +use crate::{ + stats::{ApproxStats, PlanStats, StatsState}, + LogicalPlan, +}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Distinct { // Upstream node. 
pub input: Arc, + pub stats_state: StatsState, } impl Distinct { pub(crate) fn new(input: Arc) -> Self { - Self { input } + Self { + input, + stats_state: StatsState::NotMaterialized, + } + } + + pub(crate) fn with_materialized_stats(mut self) -> Self { + // TODO(desmond): We can simply use NDVs here. For now, do a naive estimation. + let input_stats = self.input.materialized_stats(); + let est_bytes_per_row_lower = input_stats.approx_stats.lower_bound_bytes + / (input_stats.approx_stats.lower_bound_rows.max(1)); + let approx_stats = ApproxStats { + lower_bound_rows: input_stats.approx_stats.lower_bound_rows.min(1), + upper_bound_rows: input_stats.approx_stats.upper_bound_rows, + lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes.min(1) + * est_bytes_per_row_lower, + upper_bound_bytes: input_stats.approx_stats.upper_bound_bytes, + }; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![format!("Distinct")]; + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } + res } } diff --git a/src/daft-logical-plan/src/ops/explode.rs b/src/daft-logical-plan/src/ops/explode.rs index daa7ca99b0..00624102f4 100644 --- a/src/daft-logical-plan/src/ops/explode.rs +++ b/src/daft-logical-plan/src/ops/explode.rs @@ -7,6 +7,7 @@ use snafu::ResultExt; use crate::{ logical_plan::{self, CreationSnafu}, + stats::{ApproxStats, PlanStats, StatsState}, LogicalPlan, }; @@ -17,6 +18,7 @@ pub struct Explode { // Expressions to explode. e.g. col("a") pub to_explode: Vec, pub exploded_schema: SchemaRef, + pub stats_state: StatsState, } impl Explode { @@ -59,9 +61,22 @@ impl Explode { input, to_explode, exploded_schema, + stats_state: StatsState::NotMaterialized, }) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + let input_stats = self.input.materialized_stats(); + let approx_stats = ApproxStats { + lower_bound_rows: input_stats.approx_stats.lower_bound_rows, + upper_bound_rows: None, + lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes, + upper_bound_bytes: None, + }; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!( @@ -69,6 +84,9 @@ impl Explode { self.to_explode.iter().map(|e| e.to_string()).join(", ") )); res.push(format!("Schema = {}", self.exploded_schema.short_string())); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/filter.rs b/src/daft-logical-plan/src/ops/filter.rs index 0f12bf9a49..62bb34a46a 100644 --- a/src/daft-logical-plan/src/ops/filter.rs +++ b/src/daft-logical-plan/src/ops/filter.rs @@ -7,6 +7,7 @@ use snafu::ResultExt; use crate::{ logical_plan::{CreationSnafu, Result}, + stats::{ApproxStats, PlanStats, StatsState}, LogicalPlan, }; @@ -16,6 +17,7 @@ pub struct Filter { pub input: Arc, // The Boolean expression to filter on. pub predicate: ExprRef, + pub stats_state: StatsState, } impl Filter { @@ -33,6 +35,34 @@ impl Filter { ))) .context(CreationSnafu); } - Ok(Self { input, predicate }) + Ok(Self { + input, + predicate, + stats_state: StatsState::NotMaterialized, + }) + } + + pub(crate) fn with_materialized_stats(mut self) -> Self { + // Assume no row/column pruning in cardinality-affecting operations. + // TODO(desmond): We can do better estimations here. 
For now, reuse the old logic. + let input_stats = self.input.materialized_stats(); + let upper_bound_rows = input_stats.approx_stats.upper_bound_rows; + let upper_bound_bytes = input_stats.approx_stats.upper_bound_bytes; + let approx_stats = ApproxStats { + lower_bound_rows: 0, + upper_bound_rows, + lower_bound_bytes: 0, + upper_bound_bytes, + }; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![format!("Filter: {}", self.predicate)]; + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } + res } } diff --git a/src/daft-logical-plan/src/ops/join.rs b/src/daft-logical-plan/src/ops/join.rs index db787cf3a2..5484a5c701 100644 --- a/src/daft-logical-plan/src/ops/join.rs +++ b/src/daft-logical-plan/src/ops/join.rs @@ -18,6 +18,7 @@ use uuid::Uuid; use crate::{ logical_plan::{self, CreationSnafu}, ops::Project, + stats::{ApproxStats, PlanStats, StatsState}, LogicalPlan, }; @@ -33,6 +34,7 @@ pub struct Join { pub join_type: JoinType, pub join_strategy: Option, pub output_schema: SchemaRef, + pub stats_state: StatsState, } impl std::hash::Hash for Join { @@ -49,6 +51,30 @@ impl std::hash::Hash for Join { } impl Join { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + left: Arc, + right: Arc, + left_on: Vec, + right_on: Vec, + null_equals_nulls: Option>, + join_type: JoinType, + join_strategy: Option, + output_schema: SchemaRef, + ) -> Self { + Self { + left, + right, + left_on, + right_on, + null_equals_nulls, + join_type, + join_strategy, + output_schema, + stats_state: StatsState::NotMaterialized, + } + } + #[allow(clippy::too_many_arguments)] pub(crate) fn try_new( left: Arc, @@ -129,6 +155,7 @@ impl Join { join_type, join_strategy, output_schema, + stats_state: StatsState::NotMaterialized, }) } else { let common_join_keys: HashSet<_> = @@ -224,6 +251,7 @@ impl Join { join_type, join_strategy, output_schema, + stats_state: StatsState::NotMaterialized, }) } } @@ -283,6 +311,27 @@ impl Join { .unzip() } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // Assume a Primary-key + Foreign-Key join which would yield the max of the two tables. + // TODO(desmond): We can do better estimations here. For now, use the old logic. 
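+ // Under that assumption the output cannot exceed the larger input, so take the max of the two upper bounds; if either side is unbounded (None), the estimate stays unbounded.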
+ let left_stats = self.left.materialized_stats(); + let right_stats = self.right.materialized_stats(); + let approx_stats = ApproxStats { + lower_bound_rows: 0, + upper_bound_rows: left_stats + .approx_stats + .upper_bound_rows + .and_then(|l| right_stats.approx_stats.upper_bound_rows.map(|r| l.max(r))), + lower_bound_bytes: 0, + upper_bound_bytes: left_stats + .approx_stats + .upper_bound_bytes + .and_then(|l| right_stats.approx_stats.upper_bound_bytes.map(|r| l.max(r))), + }; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!("Join: Type = {}", self.join_type)); @@ -320,6 +369,9 @@ impl Join { "Output schema = {}", self.output_schema.short_string() )); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/limit.rs b/src/daft-logical-plan/src/ops/limit.rs index 4d91ee4a84..fdb2ecab7c 100644 --- a/src/daft-logical-plan/src/ops/limit.rs +++ b/src/daft-logical-plan/src/ops/limit.rs @@ -1,6 +1,9 @@ use std::sync::Arc; -use crate::LogicalPlan; +use crate::{ + stats::{ApproxStats, PlanStats, StatsState}, + LogicalPlan, +}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Limit { @@ -11,6 +14,7 @@ pub struct Limit { // Whether to send tasks in waves (maximize throughput) or // eagerly one-at-a-time (maximize time-to-first-result) pub eager: bool, + pub stats_state: StatsState, } impl Limit { @@ -19,6 +23,46 @@ impl Limit { input, limit, eager, + stats_state: StatsState::NotMaterialized, } } + + pub(crate) fn with_materialized_stats(mut self) -> Self { + let input_stats = self.input.materialized_stats(); + let limit = self.limit as usize; + let est_bytes_per_row_lower = input_stats.approx_stats.lower_bound_bytes + / input_stats.approx_stats.lower_bound_rows.max(1); + let est_bytes_per_row_upper = + input_stats + .approx_stats + .upper_bound_bytes + .and_then(|bytes| { + input_stats + .approx_stats + .upper_bound_rows + .map(|rows| bytes / rows.max(1)) + }); + let new_lower_rows = input_stats.approx_stats.lower_bound_rows.min(limit); + let new_upper_rows = input_stats + .approx_stats + .upper_bound_rows + .map(|ub| ub.min(limit)) + .unwrap_or(limit); + let approx_stats = ApproxStats { + lower_bound_rows: new_lower_rows, + upper_bound_rows: Some(new_upper_rows), + lower_bound_bytes: new_lower_rows * est_bytes_per_row_lower, + upper_bound_bytes: est_bytes_per_row_upper.map(|x| x * new_upper_rows), + }; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![format!("Limit: {}", self.limit)]; + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } + res + } } diff --git a/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs b/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs index 1e798f4ff3..9be863a686 100644 --- a/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs +++ b/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs @@ -2,13 +2,14 @@ use std::sync::Arc; use daft_core::prelude::*; -use crate::LogicalPlan; +use crate::{stats::StatsState, LogicalPlan}; #[derive(Hash, Eq, PartialEq, Debug, Clone)] pub struct MonotonicallyIncreasingId { pub input: Arc, pub schema: Arc, pub column_name: String, + pub stats_state: StatsState, } impl MonotonicallyIncreasingId { @@ -28,6 
+29,22 @@ impl MonotonicallyIncreasingId { input, schema: Arc::new(schema_with_id), column_name: column_name.to_string(), + stats_state: StatsState::NotMaterialized, } } + + pub(crate) fn with_materialized_stats(mut self) -> Self { + // TODO(desmond): We can do better estimations with the projection schema. For now, reuse the old logic. + let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone().into()); + self + } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![format!("MonotonicallyIncreasingId")]; + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } + res + } } diff --git a/src/daft-logical-plan/src/ops/pivot.rs b/src/daft-logical-plan/src/ops/pivot.rs index da204fdb34..57ee3bb1c5 100644 --- a/src/daft-logical-plan/src/ops/pivot.rs +++ b/src/daft-logical-plan/src/ops/pivot.rs @@ -9,6 +9,7 @@ use snafu::ResultExt; use crate::{ logical_plan::{self, CreationSnafu}, + stats::StatsState, LogicalPlan, }; @@ -21,6 +22,7 @@ pub struct Pivot { pub aggregation: AggExpr, pub names: Vec, pub output_schema: SchemaRef, + pub stats_state: StatsState, } impl Pivot { @@ -78,9 +80,17 @@ impl Pivot { aggregation: agg_expr.clone(), names, output_schema, + stats_state: StatsState::NotMaterialized, }) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // TODO(desmond): Pivoting does affect cardinality, but for now we keep the old logic. + let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone().into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push("Pivot:".to_string()); @@ -95,6 +105,9 @@ impl Pivot { "Output schema = {}", self.output_schema.short_string() )); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/project.rs b/src/daft-logical-plan/src/ops/project.rs index 40f0565102..dc165d5c5c 100644 --- a/src/daft-logical-plan/src/ops/project.rs +++ b/src/daft-logical-plan/src/ops/project.rs @@ -9,6 +9,7 @@ use snafu::ResultExt; use crate::{ logical_plan::{CreationSnafu, Result}, + stats::StatsState, LogicalPlan, }; @@ -18,6 +19,7 @@ pub struct Project { pub input: Arc, pub projection: Vec, pub projected_schema: SchemaRef, + pub stats_state: StatsState, } impl Project { @@ -38,8 +40,10 @@ impl Project { input: factored_input, projection: factored_projection, projected_schema, + stats_state: StatsState::NotMaterialized, }) } + /// Create a new Projection using the specified output schema pub(crate) fn new_from_schema(input: Arc, schema: SchemaRef) -> Result { let expr: Vec = schema @@ -50,11 +54,22 @@ impl Project { Self::try_new(input, expr) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // TODO(desmond): We can do better estimations with the projection schema. For now, reuse the old logic. 
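+ // Projections preserve the input's row count, so for now the input's stats are carried through unchanged (byte estimates ignore any dropped columns).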
+ let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone().into()); + self + } + pub fn multiline_display(&self) -> Vec { - vec![format!( + let mut res = vec![format!( "Project: {}", self.projection.iter().map(|e| e.to_string()).join(", ") - )] + )]; + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } + res } fn try_factor_subexpressions( diff --git a/src/daft-logical-plan/src/ops/repartition.rs b/src/daft-logical-plan/src/ops/repartition.rs index 1dce616d62..ac12970c49 100644 --- a/src/daft-logical-plan/src/ops/repartition.rs +++ b/src/daft-logical-plan/src/ops/repartition.rs @@ -5,6 +5,7 @@ use daft_dsl::ExprResolver; use crate::{ partitioning::{HashRepartitionConfig, RepartitionSpec}, + stats::StatsState, LogicalPlan, }; @@ -13,6 +14,7 @@ pub struct Repartition { // Upstream node. pub input: Arc, pub repartition_spec: RepartitionSpec, + pub stats_state: StatsState, } impl Repartition { @@ -36,9 +38,17 @@ impl Repartition { Ok(Self { input, repartition_spec, + stats_state: StatsState::NotMaterialized, }) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // Repartitioning does not affect cardinality. + let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone().into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!( @@ -46,6 +56,9 @@ impl Repartition { self.repartition_spec.var_name(), )); res.extend(self.repartition_spec.multiline_display()); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/sample.rs b/src/daft-logical-plan/src/ops/sample.rs index 9d96594666..d11cc5d1d6 100644 --- a/src/daft-logical-plan/src/ops/sample.rs +++ b/src/daft-logical-plan/src/ops/sample.rs @@ -3,7 +3,10 @@ use std::{ sync::Arc, }; -use crate::LogicalPlan; +use crate::{ + stats::{PlanStats, StatsState}, + LogicalPlan, +}; #[derive(Clone, Debug, PartialEq)] pub struct Sample { @@ -12,6 +15,7 @@ pub struct Sample { pub fraction: f64, pub with_replacement: bool, pub seed: Option, + pub stats_state: StatsState, } impl Eq for Sample {} @@ -44,14 +48,28 @@ impl Sample { fraction, with_replacement, seed, + stats_state: StatsState::NotMaterialized, } } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // TODO(desmond): We can do better estimations with the projection schema. For now, reuse the old logic. 
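+ // Scale the input's cardinality estimates by the sampling fraction.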
+ let input_stats = self.input.materialized_stats(); + let approx_stats = input_stats + .approx_stats + .apply(|v| ((v as f64) * self.fraction) as usize); + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!("Sample: {}", self.fraction)); res.push(format!("With replacement = {}", self.with_replacement)); res.push(format!("Seed = {:?}", self.seed)); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/set_operations.rs b/src/daft-logical-plan/src/ops/set_operations.rs index 017104f226..65ceb807b8 100644 --- a/src/daft-logical-plan/src/ops/set_operations.rs +++ b/src/daft-logical-plan/src/ops/set_operations.rs @@ -183,10 +183,7 @@ impl Union { (self.lhs.clone(), self.rhs.clone()) }; // we don't want to use `try_new` as we have already checked the schema - let concat = LogicalPlan::Concat(Concat { - input: lhs, - other: rhs, - }); + let concat = LogicalPlan::Concat(Concat::new(lhs, rhs)); if self.is_all { Ok(concat) } else { diff --git a/src/daft-logical-plan/src/ops/sink.rs b/src/daft-logical-plan/src/ops/sink.rs index 1370ab91f4..e5eb9f3f2e 100644 --- a/src/daft-logical-plan/src/ops/sink.rs +++ b/src/daft-logical-plan/src/ops/sink.rs @@ -6,7 +6,11 @@ use daft_dsl::ExprResolver; #[cfg(feature = "python")] use crate::sink_info::CatalogType; -use crate::{sink_info::SinkInfo, LogicalPlan, OutputFileInfo}; +use crate::{ + sink_info::SinkInfo, + stats::{PlanStats, StatsState}, + LogicalPlan, OutputFileInfo, +}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Sink { @@ -15,6 +19,7 @@ pub struct Sink { pub schema: SchemaRef, /// Information about the sink data location. pub sink_info: Arc, + pub stats_state: StatsState, } impl Sink { @@ -82,9 +87,17 @@ impl Sink { input, schema, sink_info, + stats_state: StatsState::NotMaterialized, }) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // Post-write DataFrame will contain paths to files that were written. + // TODO(desmond): Estimate output size via root directory and estimates for # of partitions given partitioning column. + self.stats_state = StatsState::Materialized(PlanStats::empty().into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; @@ -110,6 +123,9 @@ impl Sink { }, } res.push(format!("Output schema = {}", self.schema.short_string())); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/sort.rs b/src/daft-logical-plan/src/ops/sort.rs index 85cd8c2a64..9c2cd046fd 100644 --- a/src/daft-logical-plan/src/ops/sort.rs +++ b/src/daft-logical-plan/src/ops/sort.rs @@ -6,7 +6,7 @@ use daft_dsl::{ExprRef, ExprResolver}; use itertools::Itertools; use snafu::ResultExt; -use crate::{logical_plan, logical_plan::CreationSnafu, LogicalPlan}; +use crate::{logical_plan, logical_plan::CreationSnafu, stats::StatsState, LogicalPlan}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Sort { @@ -15,6 +15,7 @@ pub struct Sort { pub sort_by: Vec, pub descending: Vec, pub nulls_first: Vec, + pub stats_state: StatsState, } impl Sort { @@ -54,9 +55,17 @@ impl Sort { sort_by, descending, nulls_first, + stats_state: StatsState::NotMaterialized, }) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // Sorting does not affect cardinality. 
+ let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone().into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; // Must have at least one expression to sort by. @@ -76,6 +85,9 @@ impl Sort { }) .join(", "); res.push(format!("Sort: Sort by = {}", pairs)); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/source.rs b/src/daft-logical-plan/src/ops/source.rs index 111575f411..4044e08c72 100644 --- a/src/daft-logical-plan/src/ops/source.rs +++ b/src/daft-logical-plan/src/ops/source.rs @@ -1,9 +1,13 @@ use std::sync::Arc; -use common_scan_info::PhysicalScanInfo; +use common_error::DaftResult; +use common_scan_info::{PhysicalScanInfo, ScanState}; use daft_schema::schema::SchemaRef; -use crate::source_info::{InMemoryInfo, PlaceHolderInfo, SourceInfo}; +use crate::{ + source_info::{InMemoryInfo, PlaceHolderInfo, SourceInfo}, + stats::{ApproxStats, PlanStats, StatsState}, +}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Source { @@ -13,6 +17,7 @@ pub struct Source { /// Information about the source data location. pub source_info: Arc, + pub stats_state: StatsState, } impl Source { @@ -20,21 +25,92 @@ impl Source { Self { output_schema, source_info, + stats_state: StatsState::NotMaterialized, } } + // Helper method that converts the ScanOperatorRef inside a Source node's PhysicalScanInfo into scan tasks. + // Should only be called if a Source node's source info contains PhysicalScanInfo. The PhysicalScanInfo + // should also hold a ScanState::Operator and not a ScanState::Tasks (which would indicate that we're + // materializing this physical scan node multiple times). + pub(crate) fn build_materialized_scan_source(mut self) -> DaftResult { + let new_physical_scan_info = match Arc::unwrap_or_clone(self.source_info) { + SourceInfo::Physical(mut physical_scan_info) => { + let scan_tasks = match &physical_scan_info.scan_state { + ScanState::Operator(scan_op) => scan_op + .0 + .to_scan_tasks(physical_scan_info.pushdowns.clone())?, + ScanState::Tasks(_) => { + panic!("Physical scan nodes are being materialized more than once"); + } + }; + physical_scan_info.scan_state = ScanState::Tasks(Arc::new(scan_tasks)); + physical_scan_info + } + _ => panic!("Only unmaterialized physical scan nodes can be materialized"), + }; + self.source_info = Arc::new(SourceInfo::Physical(new_physical_scan_info)); + Ok(self) + } + + pub(crate) fn with_materialized_stats(mut self) -> Self { + let approx_stats = match &*self.source_info { + SourceInfo::InMemory(InMemoryInfo { + size_bytes, + num_rows, + .. 
+ }) => ApproxStats { + lower_bound_rows: *num_rows, + upper_bound_rows: Some(*num_rows), + lower_bound_bytes: *size_bytes, + upper_bound_bytes: Some(*size_bytes), + }, + SourceInfo::Physical(physical_scan_info) => match &physical_scan_info.scan_state { + ScanState::Operator(_) => { + panic!("Scan nodes should be materialized before stats are materialized") + } + ScanState::Tasks(scan_tasks) => { + let mut approx_stats = ApproxStats::empty(); + for st in scan_tasks.iter() { + approx_stats.lower_bound_rows += st.num_rows().unwrap_or(0); + let in_memory_size = st.estimate_in_memory_size_bytes(None); + approx_stats.lower_bound_bytes += in_memory_size.unwrap_or(0); + if let Some(st_ub) = st.upper_bound_rows() { + if let Some(ub) = approx_stats.upper_bound_rows { + approx_stats.upper_bound_rows = Some(ub + st_ub); + } else { + approx_stats.upper_bound_rows = st.upper_bound_rows(); + } + } + if let Some(st_ub) = in_memory_size { + if let Some(ub) = approx_stats.upper_bound_bytes { + approx_stats.upper_bound_bytes = Some(ub + st_ub); + } else { + approx_stats.upper_bound_bytes = in_memory_size; + } + } + } + approx_stats + } + }, + SourceInfo::PlaceHolder(_) => ApproxStats::empty(), + }; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; match self.source_info.as_ref() { SourceInfo::Physical(PhysicalScanInfo { source_schema, - scan_op, + scan_state, partitioning_keys, pushdowns, }) => { use itertools::Itertools; - res.extend(scan_op.0.multiline_display()); + res.extend(scan_state.multiline_display()); res.push(format!("File schema = {}", source_schema.short_string())); res.push(format!( @@ -61,6 +137,9 @@ impl Source { "Output schema = {}", self.output_schema.short_string() )); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/unpivot.rs b/src/daft-logical-plan/src/ops/unpivot.rs index cec9cd1c00..46a7071bf5 100644 --- a/src/daft-logical-plan/src/ops/unpivot.rs +++ b/src/daft-logical-plan/src/ops/unpivot.rs @@ -8,6 +8,7 @@ use snafu::ResultExt; use crate::{ logical_plan::{self, CreationSnafu}, + stats::{ApproxStats, PlanStats, StatsState}, LogicalPlan, }; @@ -19,9 +20,30 @@ pub struct Unpivot { pub variable_name: String, pub value_name: String, pub output_schema: SchemaRef, + pub stats_state: StatsState, } impl Unpivot { + pub(crate) fn new( + input: Arc, + ids: Vec, + values: Vec, + variable_name: String, + value_name: String, + output_schema: SchemaRef, + ) -> Self { + Self { + input, + ids, + values, + variable_name, + value_name, + output_schema, + stats_state: StatsState::NotMaterialized, + } + } + + // Similar to new, except that `try_new` is not given the output schema and instead extracts it. 
pub(crate) fn try_new( input: Arc, ids: Vec, @@ -71,9 +93,26 @@ impl Unpivot { variable_name: variable_name.to_string(), value_name: value_name.to_string(), output_schema, + stats_state: StatsState::NotMaterialized, }) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + let input_stats = self.input.materialized_stats(); + let num_values = self.values.len(); + let approx_stats = ApproxStats { + lower_bound_rows: input_stats.approx_stats.lower_bound_rows * num_values, + upper_bound_rows: input_stats + .approx_stats + .upper_bound_rows + .map(|v| v * num_values), + lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes, + upper_bound_bytes: input_stats.approx_stats.upper_bound_bytes, + }; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!( @@ -85,6 +124,9 @@ impl Unpivot { self.ids.iter().map(|e| e.to_string()).join(", ") )); res.push(format!("Schema = {}", self.output_schema.short_string())); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/optimization/optimizer.rs b/src/daft-logical-plan/src/optimization/optimizer.rs index 084018522a..61a3ff314e 100644 --- a/src/daft-logical-plan/src/optimization/optimizer.rs +++ b/src/daft-logical-plan/src/optimization/optimizer.rs @@ -6,8 +6,8 @@ use common_treenode::Transformed; use super::{ logical_plan_tracker::LogicalPlanTracker, rules::{ - DropRepartition, EliminateCrossJoin, LiftProjectFromAgg, OptimizerRule, PushDownFilter, - PushDownLimit, PushDownProjection, SplitActorPoolProjects, + DropRepartition, EliminateCrossJoin, EnrichWithStats, LiftProjectFromAgg, MaterializeScans, + OptimizerRule, PushDownFilter, PushDownLimit, PushDownProjection, SplitActorPoolProjects, }, }; use crate::LogicalPlan; @@ -136,6 +136,18 @@ impl Optimizer { RuleExecutionStrategy::FixedPoint(Some(3)), )); + // --- Materialize scan nodes --- + rule_batches.push(RuleBatch::new( + vec![Box::new(MaterializeScans::new())], + RuleExecutionStrategy::Once, + )); + + // --- Enrich logical plan with stats --- + rule_batches.push(RuleBatch::new( + vec![Box::new(EnrichWithStats::new())], + RuleExecutionStrategy::Once, + )); + Self::with_rule_batches(rule_batches, config) } diff --git a/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs b/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs index c8e888fecf..e9e3a2e524 100644 --- a/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs +++ b/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs @@ -52,7 +52,9 @@ impl OptimizerRule for EliminateCrossJoin { if !can_flatten_join_inputs(filter.input.as_ref()) { return Ok(Transformed::no(Arc::new(LogicalPlan::Filter(filter)))); } - let Filter { input, predicate } = filter; + let Filter { + input, predicate, .. 
+ } = filter; flatten_join_inputs( Arc::unwrap_or_clone(input), &mut possible_join_keys, @@ -306,16 +308,16 @@ fn find_inner_join( .non_distinct_union(right_input.schema().as_ref()); let (left_keys, right_keys) = join_keys.iter().cloned().unzip(); - return Ok(LogicalPlan::Join(Join { - left: left_input, - right: right_input, - left_on: left_keys, - right_on: right_keys, - null_equals_nulls: None, - join_type: JoinType::Inner, - join_strategy: None, - output_schema: Arc::new(join_schema), - }) + return Ok(LogicalPlan::Join(Join::new( + left_input, + right_input, + left_keys, + right_keys, + None, + JoinType::Inner, + None, + Arc::new(join_schema), + )) .arced()); } } @@ -327,16 +329,16 @@ fn find_inner_join( .schema() .non_distinct_union(right.schema().as_ref()); - Ok(LogicalPlan::Join(Join { - left: left_input, + Ok(LogicalPlan::Join(Join::new( + left_input, right, - left_on: vec![], - right_on: vec![], - null_equals_nulls: None, - join_type: JoinType::Inner, - join_strategy: None, - output_schema: Arc::new(join_schema), - }) + vec![], + vec![], + None, + JoinType::Inner, + None, + Arc::new(join_schema), + )) .arced()) } @@ -449,14 +451,14 @@ mod tests { ]) .unwrap(), ); - LogicalPlan::Source(Source { - output_schema: schema.clone(), - source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { + LogicalPlan::Source(Source::new( + schema.clone(), + Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { source_schema: schema, clustering_spec: Arc::new(ClusteringSpec::unknown()), source_id: 0, })), - }) + )) .arced() } @@ -470,14 +472,14 @@ mod tests { ]) .unwrap(), ); - LogicalPlan::Source(Source { - output_schema: schema.clone(), - source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { + LogicalPlan::Source(Source::new( + schema.clone(), + Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { source_schema: schema, clustering_spec: Arc::new(ClusteringSpec::unknown()), source_id: 0, })), - }) + )) .arced() } diff --git a/src/daft-logical-plan/src/optimization/rules/enrich_with_stats.rs b/src/daft-logical-plan/src/optimization/rules/enrich_with_stats.rs new file mode 100644 index 0000000000..f6bc338ab6 --- /dev/null +++ b/src/daft-logical-plan/src/optimization/rules/enrich_with_stats.rs @@ -0,0 +1,27 @@ +#[derive(Default, Debug)] +pub struct EnrichWithStats {} + +impl EnrichWithStats { + pub fn new() -> Self { + Self {} + } +} +use std::sync::Arc; + +use common_error::DaftResult; +use common_treenode::{Transformed, TreeNode}; + +use super::OptimizerRule; +use crate::LogicalPlan; + +// Add stats to all logical plan nodes in a bottom up fashion. +// All scan nodes MUST be materialized before stats are enriched. 
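+// `transform_up` rewrites children before parents, so every node's inputs already carry materialized stats by the time `with_materialized_stats` runs on it.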
+impl OptimizerRule for EnrichWithStats { + fn try_optimize(&self, plan: Arc) -> DaftResult>> { + plan.transform_up(|node: Arc| { + Ok(Transformed::yes( + Arc::unwrap_or_clone(node).with_materialized_stats().into(), + )) + }) + } +} diff --git a/src/daft-logical-plan/src/optimization/rules/materialize_scans.rs b/src/daft-logical-plan/src/optimization/rules/materialize_scans.rs new file mode 100644 index 0000000000..4b9f9707ed --- /dev/null +++ b/src/daft-logical-plan/src/optimization/rules/materialize_scans.rs @@ -0,0 +1,47 @@ +#[derive(Default, Debug)] +pub struct MaterializeScans {} + +impl MaterializeScans { + pub fn new() -> Self { + Self {} + } +} +use std::sync::Arc; + +use common_error::DaftResult; +use common_treenode::{Transformed, TreeNode}; + +use super::OptimizerRule; +use crate::{LogicalPlan, SourceInfo}; + +// Materialize scan tasks from scan operators for all physical scans. +impl OptimizerRule for MaterializeScans { + fn try_optimize(&self, plan: Arc) -> DaftResult>> { + plan.transform_up(|node| self.try_optimize_node(node)) + } +} + +impl MaterializeScans { + #[allow(clippy::only_used_in_recursion)] + fn try_optimize_node( + &self, + plan: Arc, + ) -> DaftResult>> { + match &*plan { + LogicalPlan::Source(source) => match &*source.source_info { + SourceInfo::Physical(_) => { + let source_plan = Arc::unwrap_or_clone(plan); + if let LogicalPlan::Source(source) = source_plan { + Ok(Transformed::yes( + source.build_materialized_scan_source()?.into(), + )) + } else { + unreachable!("This logical plan was already matched as a Source node") + } + } + _ => Ok(Transformed::no(plan)), + }, + _ => Ok(Transformed::no(plan)), + } + } +} diff --git a/src/daft-logical-plan/src/optimization/rules/mod.rs b/src/daft-logical-plan/src/optimization/rules/mod.rs index 78ce51533a..75e0f36c88 100644 --- a/src/daft-logical-plan/src/optimization/rules/mod.rs +++ b/src/daft-logical-plan/src/optimization/rules/mod.rs @@ -1,6 +1,8 @@ mod drop_repartition; mod eliminate_cross_join; +mod enrich_with_stats; mod lift_project_from_agg; +mod materialize_scans; mod push_down_filter; mod push_down_limit; mod push_down_projection; @@ -9,7 +11,9 @@ mod split_actor_pool_projects; pub use drop_repartition::DropRepartition; pub use eliminate_cross_join::EliminateCrossJoin; +pub use enrich_with_stats::EnrichWithStats; pub use lift_project_from_agg::LiftProjectFromAgg; +pub use materialize_scans::MaterializeScans; pub use push_down_filter::PushDownFilter; pub use push_down_limit::PushDownLimit; pub use push_down_projection::PushDownProjection; diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs index 2cf2ea14ad..9fa30ea8e5 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs @@ -113,7 +113,7 @@ impl PushDownFilter { needing_filter_op, } = rewrite_predicate_for_partitioning( &new_predicate, - external_info.scan_op.0.partitioning_keys(), + external_info.scan_state.get_scan_op().0.partitioning_keys(), )?; assert!( partition_only_filter.len() @@ -239,7 +239,7 @@ impl PushDownFilter { .into(); child_plan.with_new_children(&[new_filter]).into() } - LogicalPlan::Concat(Concat { input, other }) => { + LogicalPlan::Concat(Concat { input, other, .. }) => { // Push filter into each side of the concat. 
let new_input: LogicalPlan = Filter::try_new(input.clone(), filter.predicate.clone())?.into(); diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs b/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs index b32ab85fbb..e79879604d 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs @@ -37,6 +37,7 @@ impl PushDownLimit { input, limit, eager, + .. }) => { let limit = *limit as usize; match input.as_ref() { @@ -74,7 +75,12 @@ impl PushDownLimit { SourceInfo::Physical(new_external_info).into(), )) .into(); - let out_plan = if external_info.scan_op.0.can_absorb_limit() { + let out_plan = if external_info + .scan_state + .get_scan_op() + .0 + .can_absorb_limit() + { new_source } else { plan.with_new_children(&[new_source]).into() @@ -93,6 +99,7 @@ impl PushDownLimit { input, limit: child_limit, eager: child_eagar, + .. }) => { let new_limit = limit.min(*child_limit as usize); let new_eager = eager | child_eagar; diff --git a/src/daft-logical-plan/src/source_info/mod.rs b/src/daft-logical-plan/src/source_info/mod.rs index 07ae94b841..11122464a0 100644 --- a/src/daft-logical-plan/src/source_info/mod.rs +++ b/src/daft-logical-plan/src/source_info/mod.rs @@ -16,7 +16,7 @@ use { use crate::partitioning::ClusteringSpecRef; -#[derive(Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum SourceInfo { InMemory(InMemoryInfo), Physical(PhysicalScanInfo), @@ -78,7 +78,7 @@ impl Hash for InMemoryInfo { static PLACEHOLDER_ID_COUNTER: AtomicUsize = AtomicUsize::new(0); -#[derive(Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct PlaceHolderInfo { pub source_schema: SchemaRef, pub clustering_spec: ClusteringSpecRef, diff --git a/src/daft-logical-plan/src/stats.rs b/src/daft-logical-plan/src/stats.rs new file mode 100644 index 0000000000..22c3f85198 --- /dev/null +++ b/src/daft-logical-plan/src/stats.rs @@ -0,0 +1,146 @@ +use std::{fmt::Display, hash::Hash, ops::Deref}; + +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)] +pub enum StatsState { + Materialized(AlwaysSame), + NotMaterialized, +} + +impl StatsState { + pub fn materialized_stats(&self) -> &PlanStats { + match self { + Self::Materialized(stats) => stats, + Self::NotMaterialized => panic!("Tried to get unmaterialized stats"), + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct PlanStats { + // Currently we're only putting cardinality stats in the plan stats. + // In the future we want to start including column stats, including min, max, NDVs, etc. 
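+ // Cardinality is tracked as lower/upper bounds on rows and bytes; an upper bound of `None` means the bound is unknown.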
+ pub approx_stats: ApproxStats, +} + +impl PlanStats { + pub fn new(approx_stats: ApproxStats) -> Self { + Self { approx_stats } + } + + pub fn empty() -> Self { + Self { + approx_stats: ApproxStats::empty(), + } + } +} + +impl Default for PlanStats { + fn default() -> Self { + Self::empty() + } +} + +impl Display for PlanStats { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{{ Lower bound rows = {}, Upper bound rows = {}, Lower bound bytes = {}, Upper bound bytes = {} }}", + self.approx_stats.lower_bound_rows, + self.approx_stats.upper_bound_rows.map_or("None".to_string(), |v| v.to_string()), + self.approx_stats.lower_bound_bytes, + self.approx_stats.upper_bound_bytes.map_or("None".to_string(), |v| v.to_string()), + ) + } +} + +// We implement PartialEq, Eq, and Hash for AlwaysSame, then add PlanStats to LogicalPlans wrapped by AlwaysSame. +// This allows all PlanStats to be considered equal, so that logical/physical plans that are enriched with +// stats can easily implement PartialEq, Eq, and Hash in a way that ignores PlanStats when considering equality. + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct AlwaysSame(T); + +impl Deref for AlwaysSame { + type Target = T; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Hash for AlwaysSame { + #[inline] + fn hash(&self, _state: &mut H) { + // Add nothing to hash state since all AlwaysSame should hash the same. + } +} + +impl Eq for AlwaysSame {} + +impl PartialEq for AlwaysSame { + #[inline] + fn eq(&self, _other: &Self) -> bool { + true + } +} + +impl From for AlwaysSame { + #[inline] + fn from(value: T) -> Self { + Self(value) + } +} + +impl Display for AlwaysSame { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] +pub struct ApproxStats { + pub lower_bound_rows: usize, + pub upper_bound_rows: Option, + pub lower_bound_bytes: usize, + pub upper_bound_bytes: Option, +} + +impl ApproxStats { + pub fn empty() -> Self { + Self { + lower_bound_rows: 0, + upper_bound_rows: None, + lower_bound_bytes: 0, + upper_bound_bytes: None, + } + } + pub fn apply usize>(&self, f: F) -> Self { + Self { + lower_bound_rows: f(self.lower_bound_rows), + upper_bound_rows: self.upper_bound_rows.map(&f), + lower_bound_bytes: f(self.lower_bound_rows), + upper_bound_bytes: self.upper_bound_bytes.map(&f), + } + } +} + +use std::ops::Add; +impl Add for &ApproxStats { + type Output = ApproxStats; + fn add(self, rhs: Self) -> Self::Output { + ApproxStats { + lower_bound_rows: self.lower_bound_rows + rhs.lower_bound_rows, + upper_bound_rows: self + .upper_bound_rows + .and_then(|l_ub| rhs.upper_bound_rows.map(|v| v + l_ub)), + lower_bound_bytes: self.lower_bound_bytes + rhs.lower_bound_bytes, + upper_bound_bytes: self + .upper_bound_bytes + .and_then(|l_ub| rhs.upper_bound_bytes.map(|v| v + l_ub)), + } + } +} diff --git a/src/daft-parquet/src/metadata.rs b/src/daft-parquet/src/metadata.rs index 32c1090ddd..aab541c04f 100644 --- a/src/daft-parquet/src/metadata.rs +++ b/src/daft-parquet/src/metadata.rs @@ -176,8 +176,8 @@ fn apply_field_ids_to_parquet_file_metadata( let new_row_groups = file_metadata .row_groups - .into_values() - .map(|rg| { + .iter() + .map(|(_, rg)| { let new_columns = rg .columns() .iter() diff --git a/src/daft-physical-plan/src/ops/scan.rs b/src/daft-physical-plan/src/ops/scan.rs index 6c8472beda..6333b05100 100644 --- 
a/src/daft-physical-plan/src/ops/scan.rs +++ b/src/daft-physical-plan/src/ops/scan.rs @@ -6,15 +6,15 @@ use common_scan_info::ScanTaskLikeRef; use daft_logical_plan::partitioning::ClusteringSpec; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct TabularScan { - pub scan_tasks: Vec, + pub scan_tasks: Arc>, pub clustering_spec: Arc, } impl TabularScan { pub(crate) fn new( - scan_tasks: Vec, + scan_tasks: Arc>, clustering_spec: Arc, ) -> Self { Self { @@ -100,7 +100,7 @@ Clustering spec = {{ {clustering_spec} }} let mut s = base_display(self); writeln!(s, "Scan Tasks: [").unwrap(); - for st in &self.scan_tasks { + for st in self.scan_tasks.iter() { writeln!(s, "{}", st.as_ref().display_as(DisplayLevel::Verbose)).unwrap(); } s diff --git a/src/daft-physical-plan/src/physical_planner/planner.rs b/src/daft-physical-plan/src/physical_planner/planner.rs index f5c146cd2e..23ad09376d 100644 --- a/src/daft-physical-plan/src/physical_planner/planner.rs +++ b/src/daft-physical-plan/src/physical_planner/planner.rs @@ -215,6 +215,7 @@ impl TreeNodeRewriter for ReplacePlaceholdersWithMaterializedResult { LogicalPlan::Source(Source { output_schema: _, source_info, + .. }) => match source_info.as_ref() { SourceInfo::PlaceHolder(phi) => { assert!(self.mat_results.is_some()); @@ -226,10 +227,10 @@ impl TreeNodeRewriter for ReplacePlaceholdersWithMaterializedResult { mat_results.in_memory_info.clustering_spec = Some(phi.clustering_spec.clone()); mat_results.in_memory_info.source_schema = phi.source_schema.clone(); - let new_source_node = LogicalPlan::Source(Source { - output_schema: mat_results.in_memory_info.source_schema.clone(), - source_info: SourceInfo::InMemory(mat_results.in_memory_info).into(), - }) + let new_source_node = LogicalPlan::Source(Source::new( + mat_results.in_memory_info.source_schema.clone(), + SourceInfo::InMemory(mat_results.in_memory_info).into(), + )) .arced(); Ok(Transformed::new( new_source_node, diff --git a/src/daft-physical-plan/src/physical_planner/translate.rs b/src/daft-physical-plan/src/physical_planner/translate.rs index 9bffaef97b..ec4e4a1985 100644 --- a/src/daft-physical-plan/src/physical_planner/translate.rs +++ b/src/daft-physical-plan/src/physical_planner/translate.rs @@ -7,7 +7,7 @@ use std::{ use common_daft_config::DaftExecutionConfig; use common_error::{DaftError, DaftResult}; use common_file_formats::FileFormat; -use common_scan_info::PhysicalScanInfo; +use common_scan_info::{PhysicalScanInfo, ScanState, SPLIT_AND_MERGE_PASS}; use daft_core::prelude::*; use daft_dsl::{ col, functions::agg::merge_mean, is_partition_compatible, AggExpr, ApproxPercentileParams, @@ -38,15 +38,22 @@ pub(super) fn translate_single_logical_node( physical_children: &mut Vec, cfg: &DaftExecutionConfig, ) -> DaftResult { - match logical_plan { + let physical_plan = match logical_plan { LogicalPlan::Source(Source { source_info, .. }) => match source_info.as_ref() { SourceInfo::Physical(PhysicalScanInfo { pushdowns, - scan_op, + scan_state, source_schema, .. }) => { - let scan_tasks = scan_op.0.to_scan_tasks(pushdowns.clone(), Some(cfg))?; + let scan_tasks = { + match scan_state { + ScanState::Operator(scan_op) => { + Arc::new(scan_op.0.to_scan_tasks(pushdowns.clone())?) 
+ } + ScanState::Tasks(scan_tasks) => scan_tasks.clone(), + } + }; if scan_tasks.is_empty() { let clustering_spec = @@ -58,6 +65,14 @@ pub(super) fn translate_single_logical_node( )) .arced()) } else { + // Perform scan task splitting and merging. + let scan_tasks = if let Some(split_and_merge_pass) = SPLIT_AND_MERGE_PASS.get() + { + split_and_merge_pass(scan_tasks, pushdowns, cfg)? + } else { + scan_tasks + }; + let clustering_spec = Arc::new(ClusteringSpec::Unknown( UnknownClusteringConfig::new(scan_tasks.len()), )); @@ -205,7 +220,7 @@ pub(super) fn translate_single_logical_node( }; Ok(repartitioned_plan.arced()) } - LogicalPlan::Distinct(LogicalDistinct { input }) => { + LogicalPlan::Distinct(LogicalDistinct { input, .. }) => { let input_physical = physical_children.pop().expect("requires 1 input"); let col_exprs = input .schema() @@ -756,7 +771,12 @@ pub(super) fn translate_single_logical_node( LogicalPlan::Union(_) => Err(DaftError::InternalError( "Union should already be optimized away".to_string(), )), - } + }?; + // TODO(desmond): We can't perform this check for now because ScanTasks currently provide + // different size estimations depending on when the approximation is computed. Once we fix + // this, we can add back in the assertion here. + // debug_assert!(logical_plan.get_stats().approx_stats == physical_plan.approximate_stats()); + Ok(physical_plan) } pub fn extract_agg_expr(expr: &ExprRef) -> DaftResult { diff --git a/src/daft-physical-plan/src/plan.rs b/src/daft-physical-plan/src/plan.rs index 740905b6e8..acf456f14d 100644 --- a/src/daft-physical-plan/src/plan.rs +++ b/src/daft-physical-plan/src/plan.rs @@ -1,8 +1,11 @@ -use std::{cmp::max, collections::HashSet, ops::Add, sync::Arc}; +use std::{cmp::max, collections::HashSet, sync::Arc}; use common_display::ascii::AsciiTreeDisplay; -use daft_logical_plan::partitioning::{ - ClusteringSpec, HashClusteringConfig, RangeClusteringConfig, UnknownClusteringConfig, +use daft_logical_plan::{ + partitioning::{ + ClusteringSpec, HashClusteringConfig, RangeClusteringConfig, UnknownClusteringConfig, + }, + stats::ApproxStats, }; use serde::{Deserialize, Serialize}; @@ -43,48 +46,6 @@ pub enum PhysicalPlan { LanceWrite(LanceWrite), } -pub struct ApproxStats { - pub lower_bound_rows: usize, - pub upper_bound_rows: Option, - pub lower_bound_bytes: usize, - pub upper_bound_bytes: Option, -} - -impl ApproxStats { - fn empty() -> Self { - Self { - lower_bound_rows: 0, - upper_bound_rows: None, - lower_bound_bytes: 0, - upper_bound_bytes: None, - } - } - fn apply usize>(&self, f: F) -> Self { - Self { - lower_bound_rows: f(self.lower_bound_rows), - upper_bound_rows: self.upper_bound_rows.map(&f), - lower_bound_bytes: f(self.lower_bound_rows), - upper_bound_bytes: self.upper_bound_bytes.map(&f), - } - } -} - -impl Add for &ApproxStats { - type Output = ApproxStats; - fn add(self, rhs: Self) -> Self::Output { - ApproxStats { - lower_bound_rows: self.lower_bound_rows + rhs.lower_bound_rows, - upper_bound_rows: self - .upper_bound_rows - .and_then(|l_ub| rhs.upper_bound_rows.map(|v| v + l_ub)), - lower_bound_bytes: self.lower_bound_bytes + rhs.lower_bound_bytes, - upper_bound_bytes: self - .upper_bound_bytes - .and_then(|l_ub| rhs.upper_bound_bytes.map(|v| v + l_ub)), - } - } -} - impl PhysicalPlan { pub fn arced(self) -> PhysicalPlanRef { Arc::new(self) @@ -229,16 +190,24 @@ impl PhysicalPlan { }, Self::TabularScan(TabularScan { scan_tasks, .. 
}) => { let mut stats = ApproxStats::empty(); - for st in scan_tasks { + for st in scan_tasks.iter() { stats.lower_bound_rows += st.num_rows().unwrap_or(0); let in_memory_size = st.estimate_in_memory_size_bytes(None); stats.lower_bound_bytes += in_memory_size.unwrap_or(0); - stats.upper_bound_rows = stats - .upper_bound_rows - .and_then(|st_ub| st.upper_bound_rows().map(|ub| st_ub + ub)); - stats.upper_bound_bytes = stats - .upper_bound_bytes - .and_then(|st_ub| in_memory_size.map(|ub| st_ub + ub)); + if let Some(st_ub) = st.upper_bound_rows() { + if let Some(ub) = stats.upper_bound_rows { + stats.upper_bound_rows = Some(ub + st_ub); + } else { + stats.upper_bound_rows = st.upper_bound_rows(); + } + } + if let Some(st_ub) = in_memory_size { + if let Some(ub) = stats.upper_bound_bytes { + stats.upper_bound_bytes = Some(ub + st_ub); + } else { + stats.upper_bound_bytes = in_memory_size; + } + } } stats } diff --git a/src/daft-scan/Cargo.toml b/src/daft-scan/Cargo.toml index 12eeb71f5b..49344c53ab 100644 --- a/src/daft-scan/Cargo.toml +++ b/src/daft-scan/Cargo.toml @@ -8,6 +8,7 @@ common-io-config = {path = "../common/io-config", default-features = false} common-py-serde = {path = "../common/py-serde", default-features = false} common-runtime = {path = "../common/runtime", default-features = false} common-scan-info = {path = "../common/scan-info", default-features = false} +ctor = "0.2.9" daft-core = {path = "../daft-core", default-features = false} daft-csv = {path = "../daft-csv", default-features = false} daft-decoding = {path = "../daft-decoding", default-features = false} diff --git a/src/daft-scan/src/anonymous.rs b/src/daft-scan/src/anonymous.rs index 2931996809..17f8c6574a 100644 --- a/src/daft-scan/src/anonymous.rs +++ b/src/daft-scan/src/anonymous.rs @@ -1,16 +1,11 @@ use std::sync::Arc; -use common_daft_config::DaftExecutionConfig; use common_error::DaftResult; use common_file_formats::{FileFormatConfig, ParquetSourceConfig}; use common_scan_info::{PartitionField, Pushdowns, ScanOperator, ScanTaskLike, ScanTaskLikeRef}; use daft_schema::schema::SchemaRef; -use crate::{ - scan_task_iters::{merge_by_sizes, split_by_row_groups, BoxScanTaskIter}, - storage_config::StorageConfig, - ChunkSpec, DataSource, ScanTask, -}; +use crate::{storage_config::StorageConfig, ChunkSpec, DataSource, ScanTask}; #[derive(Debug)] pub struct AnonymousScanOperator { files: Vec, @@ -74,11 +69,7 @@ impl ScanOperator for AnonymousScanOperator { lines } - fn to_scan_tasks( - &self, - pushdowns: Pushdowns, - cfg: Option<&DaftExecutionConfig>, - ) -> DaftResult> { + fn to_scan_tasks(&self, pushdowns: Pushdowns) -> DaftResult> { let files = self.files.clone(); let file_format_config = self.file_format_config.clone(); let schema = self.schema.clone(); @@ -95,10 +86,12 @@ impl ScanOperator for AnonymousScanOperator { }; // Create one ScanTask per file. 
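+ // Scan task splitting and merging now happens during physical planning (via SPLIT_AND_MERGE_PASS), so each file maps directly to a single ScanTask here.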
- let mut scan_tasks: BoxScanTaskIter = - Box::new(files.into_iter().zip(row_groups).map(|(f, rg)| { + Ok(files + .into_iter() + .zip(row_groups) + .map(|(f, rg)| { let chunk_spec = rg.map(ChunkSpec::Parquet); - Ok(ScanTask::new( + Arc::new(ScanTask::new( vec![DataSource::File { path: f, chunk_spec, @@ -114,23 +107,9 @@ impl ScanOperator for AnonymousScanOperator { storage_config.clone(), pushdowns.clone(), None, - ) - .into()) - })); - - if let Some(cfg) = cfg { - scan_tasks = split_by_row_groups( - scan_tasks, - cfg.parquet_split_row_groups_max_files, - cfg.scan_tasks_min_size_bytes, - cfg.scan_tasks_max_size_bytes, - ); - - scan_tasks = merge_by_sizes(scan_tasks, &pushdowns, cfg); - } - - scan_tasks - .map(|st| st.map(|task| task as Arc)) - .collect() + )) + }) + .map(|st| st as Arc) + .collect()) } } diff --git a/src/daft-scan/src/glob.rs b/src/daft-scan/src/glob.rs index 899e0ebc89..2f8d0f071f 100644 --- a/src/daft-scan/src/glob.rs +++ b/src/daft-scan/src/glob.rs @@ -1,6 +1,5 @@ use std::{sync::Arc, vec}; -use common_daft_config::DaftExecutionConfig; use common_error::{DaftError, DaftResult}; use common_file_formats::{CsvSourceConfig, FileFormat, FileFormatConfig, ParquetSourceConfig}; use common_runtime::RuntimeRef; @@ -21,7 +20,6 @@ use snafu::Snafu; use crate::{ hive::{hive_partitions_to_fields, hive_partitions_to_series, parse_hive_partitioning}, - scan_task_iters::{merge_by_sizes, split_by_row_groups, BoxScanTaskIter}, storage_config::StorageConfig, ChunkSpec, DataSource, ScanTask, }; @@ -355,11 +353,7 @@ impl ScanOperator for GlobScanOperator { lines } - fn to_scan_tasks( - &self, - pushdowns: Pushdowns, - cfg: Option<&DaftExecutionConfig>, - ) -> DaftResult> { + fn to_scan_tasks(&self, pushdowns: Pushdowns) -> DaftResult> { let (io_runtime, io_client) = self.storage_config.get_io_client_and_runtime()?; let io_stats = IOStatsContext::new(format!( "GlobScanOperator::to_scan_tasks for {:#?}", @@ -397,89 +391,79 @@ impl ScanOperator for GlobScanOperator { .collect(); let partition_schema = Schema::new(partition_fields)?; // Create one ScanTask per file. - let mut scan_tasks: BoxScanTaskIter = Box::new(files.enumerate().filter_map(|(idx, f)| { - let scan_task_result = (|| { - let FileMetadata { - filepath: path, - size: size_bytes, - .. - } = f?; - // Create partition values from hive partitions, if any. - let mut partition_values = if hive_partitioning { - let hive_partitions = parse_hive_partitioning(&path)?; - hive_partitions_to_series(&hive_partitions, &partition_schema)? - } else { - vec![] - }; - // Extend partition values based on whether a file_path_column is set (this column is inherently a partition). - if let Some(fp_col) = &file_path_column { - let trimmed = path.trim_start_matches("file://"); - let file_paths_column_series = - Utf8Array::from_iter(fp_col, std::iter::once(Some(trimmed))).into_series(); - partition_values.push(file_paths_column_series); - } - let (partition_spec, generated_fields) = if !partition_values.is_empty() { - let partition_values_table = Table::from_nonempty_columns(partition_values)?; - // If there are partition values, evaluate them against partition filters, if any. - if let Some(partition_filters) = &pushdowns.partition_filters { - let filter_result = - partition_values_table.filter(&[partition_filters.clone()])?; - if filter_result.is_empty() { - // Skip the current file since it does not satisfy the partition filters. 
-                            return Ok(None);
-                        }
+        files
+            .enumerate()
+            .filter_map(|(idx, f)| {
+                let scan_task_result = (|| {
+                    let FileMetadata {
+                        filepath: path,
+                        size: size_bytes,
+                        ..
+                    } = f?;
+                    // Create partition values from hive partitions, if any.
+                    let mut partition_values = if hive_partitioning {
+                        let hive_partitions = parse_hive_partitioning(&path)?;
+                        hive_partitions_to_series(&hive_partitions, &partition_schema)?
+                    } else {
+                        vec![]
+                    };
+                    // Extend partition values based on whether a file_path_column is set (this column is inherently a partition).
+                    if let Some(fp_col) = &file_path_column {
+                        let trimmed = path.trim_start_matches("file://");
+                        let file_paths_column_series =
+                            Utf8Array::from_iter(fp_col, std::iter::once(Some(trimmed)))
+                                .into_series();
+                        partition_values.push(file_paths_column_series);
                     }
-                let generated_fields = partition_values_table.schema.clone();
-                let partition_spec = PartitionSpec {
-                    keys: partition_values_table,
+                    let (partition_spec, generated_fields) = if !partition_values.is_empty() {
+                        let partition_values_table =
+                            Table::from_nonempty_columns(partition_values)?;
+                        // If there are partition values, evaluate them against partition filters, if any.
+                        if let Some(partition_filters) = &pushdowns.partition_filters {
+                            let filter_result =
+                                partition_values_table.filter(&[partition_filters.clone()])?;
+                            if filter_result.is_empty() {
+                                // Skip the current file since it does not satisfy the partition filters.
+                                return Ok(None);
+                            }
+                        }
+                        let generated_fields = partition_values_table.schema.clone();
+                        let partition_spec = PartitionSpec {
+                            keys: partition_values_table,
+                        };
+                        (Some(partition_spec), Some(generated_fields))
+                    } else {
+                        (None, None)
                     };
-                (Some(partition_spec), Some(generated_fields))
-            } else {
-                (None, None)
-            };
-            let row_group = row_groups
-                .as_ref()
-                .and_then(|rgs| rgs.get(idx).cloned())
-                .flatten();
-            let chunk_spec = row_group.map(ChunkSpec::Parquet);
-            Ok(Some(ScanTask::new(
-                vec![DataSource::File {
-                    path,
-                    chunk_spec,
-                    size_bytes,
-                    iceberg_delete_files: None,
-                    metadata: None,
-                    partition_spec,
-                    statistics: None,
-                    parquet_metadata: None,
-                }],
-                file_format_config.clone(),
-                schema.clone(),
-                storage_config.clone(),
-                pushdowns.clone(),
-                generated_fields,
-            )))
-        })();
-        match scan_task_result {
-            Ok(Some(scan_task)) => Some(Ok(scan_task.into())),
-            Ok(None) => None,
-            Err(e) => Some(Err(e)),
-        }
-        }));
-
-        if let Some(cfg) = cfg {
-            scan_tasks = split_by_row_groups(
-                scan_tasks,
-                cfg.parquet_split_row_groups_max_files,
-                cfg.scan_tasks_min_size_bytes,
-                cfg.scan_tasks_max_size_bytes,
-            );
-
-            scan_tasks = merge_by_sizes(scan_tasks, &pushdowns, cfg);
-        }
-
-        scan_tasks
-            .map(|st| st.map(|task| task as Arc<dyn ScanTaskLike>))
+                    let row_group = row_groups
+                        .as_ref()
+                        .and_then(|rgs| rgs.get(idx).cloned())
+                        .flatten();
+                    let chunk_spec = row_group.map(ChunkSpec::Parquet);
+                    Ok(Some(ScanTask::new(
+                        vec![DataSource::File {
+                            path,
+                            chunk_spec,
+                            size_bytes,
+                            iceberg_delete_files: None,
+                            metadata: None,
+                            partition_spec,
+                            statistics: None,
+                            parquet_metadata: None,
+                        }],
+                        file_format_config.clone(),
+                        schema.clone(),
+                        storage_config.clone(),
+                        pushdowns.clone(),
+                        generated_fields,
+                    )))
+                })();
+                match scan_task_result {
+                    Ok(Some(scan_task)) => Some(Ok(Arc::new(scan_task) as Arc<dyn ScanTaskLike>)),
+                    Ok(None) => None,
+                    Err(e) => Some(Err(e)),
+                }
+            })
             .collect()
     }
 }
diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs
index 788c0f3b60..efae6d250c 100644
--- a/src/daft-scan/src/lib.rs
+++ b/src/daft-scan/src/lib.rs
@@ -1,6 +1,12 @@
 #![feature(if_let_guard)]
 #![feature(let_chains)]
-use std::{any::Any, borrow::Cow, fmt::Debug, sync::Arc};
+use std::{
+    any::Any,
+    borrow::Cow,
+    fmt::Debug,
+    hash::{Hash, Hasher},
+    sync::Arc,
+};
 
 use common_display::DisplayAs;
 use common_error::DaftError;
@@ -100,7 +106,7 @@ impl From<Error> for pyo3::PyErr {
 }
 
 /// Specification of a subset of a file to be read.
-#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 pub enum ChunkSpec {
     /// Selection of Parquet row groups.
     Parquet(Vec<i64>),
@@ -149,6 +155,63 @@ pub enum DataSource {
     },
 }
 
+impl Hash for DataSource {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        // Hash everything except for cached parquet metadata.
+        match self {
+            Self::File {
+                path,
+                chunk_spec,
+                size_bytes,
+                iceberg_delete_files,
+                metadata,
+                partition_spec,
+                statistics,
+                ..
+            } => {
+                path.hash(state);
+                if let Some(chunk_spec) = chunk_spec {
+                    chunk_spec.hash(state);
+                }
+                size_bytes.hash(state);
+                iceberg_delete_files.hash(state);
+                metadata.hash(state);
+                partition_spec.hash(state);
+                statistics.hash(state);
+            }
+            Self::Database {
+                path,
+                size_bytes,
+                metadata,
+                statistics,
+            } => {
+                path.hash(state);
+                size_bytes.hash(state);
+                metadata.hash(state);
+                statistics.hash(state);
+            }
+            #[cfg(feature = "python")]
+            Self::PythonFactoryFunction {
+                module,
+                func_name,
+                func_args,
+                size_bytes,
+                metadata,
+                statistics,
+                partition_spec,
+            } => {
+                module.hash(state);
+                func_name.hash(state);
+                func_args.hash(state);
+                size_bytes.hash(state);
+                metadata.hash(state);
+                statistics.hash(state);
+                partition_spec.hash(state);
+            }
+        }
+    }
+}
+
 impl DataSource {
     #[must_use]
     pub fn get_path(&self) -> &str {
@@ -349,7 +412,7 @@ impl DisplayAs for DataSource {
     }
 }
 
-#[derive(Debug, PartialEq, Serialize, Deserialize)]
+#[derive(Debug, PartialEq, Serialize, Deserialize, Hash)]
 pub struct ScanTask {
     pub sources: Vec<DataSource>,
 
@@ -388,6 +451,10 @@ impl ScanTaskLike for ScanTask {
             .map_or(false, |a| a == self)
     }
 
+    fn dyn_hash(&self, mut state: &mut dyn Hasher) {
+        self.hash(&mut state);
+    }
+
     fn materialized_schema(&self) -> SchemaRef {
         self.materialized_schema()
     }
diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs
index dd30b4541c..4c9da39372 100644
--- a/src/daft-scan/src/python.rs
+++ b/src/daft-scan/src/python.rs
@@ -1,3 +1,5 @@
+use std::hash::{Hash, Hasher};
+
 use common_py_serde::{deserialize_py_object, serialize_py_object};
 use pyo3::{prelude::*, types::PyTuple};
 use serde::{Deserialize, Serialize};
@@ -15,27 +17,48 @@ struct PyObjectSerializableWrapper(
 
 /// Python arguments to a Python function that produces Tables
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct PythonTablesFactoryArgs(Vec<PyObjectSerializableWrapper>);
+pub struct PythonTablesFactoryArgs {
+    args: Vec<PyObjectSerializableWrapper>,
+    hash: u64,
+}
+
+impl Hash for PythonTablesFactoryArgs {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.hash.hash(state);
+    }
+}
 
 impl PythonTablesFactoryArgs {
     pub fn new(args: Vec<PyObject>) -> Self {
-        Self(args.into_iter().map(PyObjectSerializableWrapper).collect())
+        let mut hasher = std::collections::hash_map::DefaultHasher::new();
+        Python::with_gil(|py| {
+            for obj in &args {
+                // Only hash hashable PyObjects.
+                if let Ok(hash) = obj.bind(py).hash() {
+                    hash.hash(&mut hasher);
+                }
+            }
+        });
+        Self {
+            args: args.into_iter().map(PyObjectSerializableWrapper).collect(),
+            hash: hasher.finish(),
+        }
     }
 
     #[must_use]
     pub fn to_pytuple<'a>(&self, py: Python<'a>) -> Bound<'a, PyTuple> {
-        pyo3::types::PyTuple::new_bound(py, self.0.iter().map(|x| x.0.bind(py)))
+        pyo3::types::PyTuple::new_bound(py, self.args.iter().map(|x| x.0.bind(py)))
     }
 }
 
 impl PartialEq for PythonTablesFactoryArgs {
     fn eq(&self, other: &Self) -> bool {
-        if self.0.len() != other.0.len() {
+        if self.args.len() != other.args.len() {
             return false;
         }
-        self.0
+        self.args
             .iter()
-            .zip(other.0.iter())
+            .zip(other.args.iter())
            .all(|(s, o)| (s.0.as_ptr() as isize) == (o.0.as_ptr() as isize))
     }
 }
@@ -43,7 +66,7 @@ impl PartialEq for PythonTablesFactoryArgs {
 pub mod pylib {
     use std::sync::Arc;
 
-    use common_daft_config::{DaftExecutionConfig, PyDaftExecutionConfig};
+    use common_daft_config::PyDaftExecutionConfig;
     use common_error::DaftResult;
     use common_file_formats::{python::PyFileFormatConfig, FileFormatConfig};
     use common_py_serde::impl_bincode_py_state_serialization;
@@ -66,7 +89,6 @@ pub mod pylib {
     use crate::{
         anonymous::AnonymousScanOperator,
         glob::GlobScanOperator,
-        scan_task_iters::{merge_by_sizes, split_by_row_groups, BoxScanTaskIter},
        storage_config::{PyStorageConfig, PythonStorageConfig},
        DataSource, ScanTask,
    };
@@ -248,11 +270,7 @@ pub mod pylib {
            lines
        }
 
-        fn to_scan_tasks(
-            &self,
-            pushdowns: Pushdowns,
-            cfg: Option<&DaftExecutionConfig>,
-        ) -> DaftResult<Vec<ScanTaskLikeRef>> {
+        fn to_scan_tasks(&self, pushdowns: Pushdowns) -> DaftResult<Vec<ScanTaskLikeRef>> {
            let scan_tasks = Python::with_gil(|py| {
                let pypd = PyPushdowns(pushdowns.clone().into()).into_py(py);
                let pyiter =
@@ -269,20 +287,8 @@ pub mod pylib {
                )
            })?;
 
-            let mut scan_tasks: BoxScanTaskIter = Box::new(scan_tasks.into_iter());
-
-            if let Some(cfg) = cfg {
-                scan_tasks = split_by_row_groups(
-                    scan_tasks,
-                    cfg.parquet_split_row_groups_max_files,
-                    cfg.scan_tasks_min_size_bytes,
-                    cfg.scan_tasks_max_size_bytes,
-                );
-
-                scan_tasks = merge_by_sizes(scan_tasks, &pushdowns, cfg);
-            }
-
            scan_tasks
+                .into_iter()
                .map(|st| st.map(|task| task as Arc<dyn ScanTaskLike>))
                .collect()
        }
diff --git a/src/daft-scan/src/scan_task_iters.rs b/src/daft-scan/src/scan_task_iters.rs
index c1c92b2483..3ee2a18ccd 100644
--- a/src/daft-scan/src/scan_task_iters.rs
+++ b/src/daft-scan/src/scan_task_iters.rs
@@ -1,8 +1,9 @@
 use std::sync::Arc;
 
 use common_daft_config::DaftExecutionConfig;
-use common_error::DaftResult;
+use common_error::{DaftError, DaftResult};
 use common_file_formats::{FileFormatConfig, ParquetSourceConfig};
+use common_scan_info::{ScanTaskLike, ScanTaskLikeRef, SPLIT_AND_MERGE_PASS};
 use daft_io::IOStatsContext;
 use daft_parquet::read::read_parquet_metadata;
 use parquet2::metadata::RowGroupList;
@@ -308,3 +309,43 @@ pub(crate) fn split_by_row_groups(
         )
     }
 }
+
+fn split_and_merge_pass(
+    scan_tasks: Arc<Vec<ScanTaskLikeRef>>,
+    pushdowns: &Pushdowns,
+    cfg: &DaftExecutionConfig,
+) -> DaftResult<Arc<Vec<ScanTaskLikeRef>>> {
+    // Perform scan task splitting and merging if there are only ScanTasks (i.e. no DummyScanTasks).
+    if scan_tasks
+        .iter()
+        .all(|st| st.as_any().downcast_ref::<ScanTask>().is_some())
+    {
+        // TODO(desmond): Here we downcast Arc<dyn ScanTaskLike> to Arc<ScanTask>. ScanTask and DummyScanTask (test only) are
+        // the only non-test implementer of ScanTaskLike. It might be possible to avoid the downcast by implementing merging
+        // at the trait level, but today that requires shifting around a non-trivial amount of code to avoid circular dependencies.
+        let iter: BoxScanTaskIter = Box::new(scan_tasks.as_ref().iter().map(|st| {
+            st.clone()
+                .as_any_arc()
+                .downcast::<ScanTask>()
+                .map_err(|e| DaftError::TypeError(format!("Expected Arc<ScanTask>, found {:?}", e)))
+        }));
+        let split_tasks = split_by_row_groups(
+            iter,
+            cfg.parquet_split_row_groups_max_files,
+            cfg.scan_tasks_min_size_bytes,
+            cfg.scan_tasks_max_size_bytes,
+        );
+        let merged_tasks = merge_by_sizes(split_tasks, pushdowns, cfg);
+        let scan_tasks: Vec<Arc<dyn ScanTaskLike>> = merged_tasks
+            .map(|st| st.map(|task| task as Arc<dyn ScanTaskLike>))
+            .collect::<DaftResult<Vec<_>>>()?;
+        Ok(Arc::new(scan_tasks))
+    } else {
+        Ok(scan_tasks)
+    }
+}
+
+#[ctor::ctor]
+fn set_pass() {
+    let _ = SPLIT_AND_MERGE_PASS.set(&split_and_merge_pass);
+}
diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs
index c304cad020..485dcf0aeb 100644
--- a/src/daft-sql/src/lib.rs
+++ b/src/daft-sql/src/lib.rs
@@ -55,14 +55,14 @@ mod tests {
         ])
         .unwrap(),
         );
-        LogicalPlan::Source(Source {
-            output_schema: schema.clone(),
-            source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo {
+        LogicalPlan::Source(Source::new(
+            schema.clone(),
+            Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo {
                 source_schema: schema,
                 clustering_spec: Arc::new(ClusteringSpec::unknown()),
                 source_id: 0,
             })),
-        })
+        ))
         .arced()
     }
@@ -76,14 +76,14 @@ mod tests {
         ])
         .unwrap(),
         );
-        LogicalPlan::Source(Source {
-            output_schema: schema.clone(),
-            source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo {
+        LogicalPlan::Source(Source::new(
+            schema.clone(),
+            Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo {
                 source_schema: schema,
                 clustering_spec: Arc::new(ClusteringSpec::unknown()),
                 source_id: 0,
             })),
-        })
+        ))
         .arced()
     }
@@ -97,14 +97,14 @@ mod tests {
         ])
         .unwrap(),
         );
-        LogicalPlan::Source(Source {
-            output_schema: schema.clone(),
-            source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo {
+        LogicalPlan::Source(Source::new(
+            schema.clone(),
+            Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo {
                 source_schema: schema,
                 clustering_spec: Arc::new(ClusteringSpec::unknown()),
                 source_id: 0,
             })),
-        })
+        ))
         .arced()
     }
@@ -191,7 +191,7 @@ mod tests {
         let sql = "select test as a from tbl1";
         let plan = planner.plan_sql(sql).unwrap();
 
-        let expected = LogicalPlanBuilder::new(tbl_1, None)
+        let expected = LogicalPlanBuilder::from(tbl_1)
             .select(vec![col("test").alias("a")])
             .unwrap()
             .build();
@@ -203,7 +203,7 @@ mod tests {
         let sql = "select test as a from tbl1 where test = 'a'";
         let plan = planner.plan_sql(sql)?;
 
-        let expected = LogicalPlanBuilder::new(tbl_1, None)
+        let expected = LogicalPlanBuilder::from(tbl_1)
             .filter(col("test").eq(lit("a")))?
             .select(vec![col("test").alias("a")])?
             .build();
@@ -216,7 +216,7 @@ mod tests {
         let sql = "select test as a from tbl1 limit 10";
         let plan = planner.plan_sql(sql)?;
 
-        let expected = LogicalPlanBuilder::new(tbl_1, None)
+        let expected = LogicalPlanBuilder::from(tbl_1)
             .select(vec![col("test").alias("a")])?
             .limit(10, true)?
             .build();
@@ -230,7 +230,7 @@ mod tests {
         let sql = "select utf8 from tbl1 order by utf8 desc";
         let plan = planner.plan_sql(sql)?;
 
-        let expected = LogicalPlanBuilder::new(tbl_1, None)
+        let expected = LogicalPlanBuilder::from(tbl_1)
             .select(vec![col("utf8")])?
             .sort(vec![col("utf8")], vec![true], vec![true])?
             .build();
@@ -241,7 +241,7 @@ mod tests {
     #[rstest]
     fn test_cast(mut planner: SQLPlanner, tbl_1: LogicalPlanRef) -> SQLPlannerResult<()> {
-        let builder = LogicalPlanBuilder::new(tbl_1, None);
+        let builder = LogicalPlanBuilder::from(tbl_1);
         let cases = vec![
             (
                 "select bool::text from tbl1",
@@ -285,7 +285,7 @@ mod tests {
             if null_equals_null { "<=>" } else { "=" }
         );
         let plan = planner.plan_sql(&sql)?;
-        let expected = LogicalPlanBuilder::new(tbl_2, None)
+        let expected = LogicalPlanBuilder::from(tbl_2)
             .join_with_null_safe_equal(
                 tbl_3,
                 vec![col("id")],
@@ -312,7 +312,7 @@ mod tests {
         let sql = "select * from tbl2 join tbl3 on tbl2.id = tbl3.id and tbl2.val > 0";
         let plan = planner.plan_sql(&sql)?;
 
-        let expected = LogicalPlanBuilder::new(tbl_2, None)
+        let expected = LogicalPlanBuilder::from(tbl_2)
             .filter(col("val").gt(lit(0 as i64)))?
             .join_with_null_safe_equal(
                 tbl_3,
@@ -394,7 +394,7 @@ mod tests {
         let sql = "select max(i32) from tbl1";
         let plan = planner.plan_sql(sql)?;
 
-        let expected = LogicalPlanBuilder::new(tbl_1, None)
+        let expected = LogicalPlanBuilder::from(tbl_1)
             .aggregate(vec![col("i32").max()], vec![])?
             .select(vec![col("i32")])?
             .build();
@@ -469,7 +469,7 @@ mod tests {
             field: Field::new("i32", DataType::Int32),
             depth: 1,
         }));
-        let subquery = LogicalPlanBuilder::new(tbl_2, None)
+        let subquery = LogicalPlanBuilder::from(tbl_2)
             .filter(col("id").eq(outer_col))?
             .aggregate(vec![col("id").max()], vec![])?
             .select(vec![col("id")])?
@@ -477,7 +477,7 @@ mod tests {
         let subquery = Arc::new(Expr::Subquery(Subquery { plan: subquery }));
 
-        let expected = LogicalPlanBuilder::new(tbl_1, None)
+        let expected = LogicalPlanBuilder::from(tbl_1)
             .filter(col("i64").gt(subquery))?
             .select(vec![col("utf8")])?
             .build();
diff --git a/src/daft-stats/src/column_stats/mod.rs b/src/daft-stats/src/column_stats/mod.rs
index 491c63ec40..f2e733815c 100644
--- a/src/daft-stats/src/column_stats/mod.rs
+++ b/src/daft-stats/src/column_stats/mod.rs
@@ -2,7 +2,10 @@ mod arithmetic;
 mod comparison;
 mod logical;
 
-use std::string::FromUtf8Error;
+use std::{
+    hash::{Hash, Hasher},
+    string::FromUtf8Error,
+};
 
 use daft_core::prelude::*;
 use snafu::{ResultExt, Snafu};
@@ -14,6 +17,24 @@ pub enum ColumnRangeStatistics {
     Loaded(Series, Series),
 }
 
+impl Hash for ColumnRangeStatistics {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        match self {
+            Self::Missing => (),
+            Self::Loaded(l, u) => {
+                let lower_hashes = l
+                    .hash(None)
+                    .expect("Failed to hash lower column range statistics");
+                lower_hashes.into_iter().for_each(|h| h.hash(state));
+                let upper_hashes = u
+                    .hash(None)
+                    .expect("Failed to hash upper column range statistics");
+                upper_hashes.into_iter().for_each(|h| h.hash(state));
+            }
+        }
+    }
+}
+
 #[derive(PartialEq, Eq, Debug)]
 pub enum TruthValue {
     False,
diff --git a/src/daft-stats/src/partition_spec.rs b/src/daft-stats/src/partition_spec.rs
index 24834bf116..79fe4c21b4 100644
--- a/src/daft-stats/src/partition_spec.rs
+++ b/src/daft-stats/src/partition_spec.rs
@@ -1,4 +1,7 @@
-use std::collections::HashMap;
+use std::{
+    collections::HashMap,
+    hash::{Hash, Hasher},
+};
 
 use daft_core::array::ops::{DaftCompare, DaftLogical};
 use daft_dsl::{ExprRef, Literal};
@@ -71,3 +74,15 @@ impl PartialEq for PartitionSpec {
     }
 }
 
 impl Eq for PartitionSpec {}
+
+// Manually implement Hash to ensure consistency with `PartialEq`.
+impl Hash for PartitionSpec {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.keys.schema.hash(state);
+
+        for column in &self.keys {
+            let column_hashes = column.hash(None).expect("Failed to hash column");
+            column_hashes.into_iter().for_each(|h| h.hash(state));
+        }
+    }
+}
diff --git a/src/daft-stats/src/table_metadata.rs b/src/daft-stats/src/table_metadata.rs
index bcd76e96c4..4db512bf57 100644
--- a/src/daft-stats/src/table_metadata.rs
+++ b/src/daft-stats/src/table_metadata.rs
@@ -1,6 +1,6 @@
 use serde::{Deserialize, Serialize};
 
-#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 pub struct TableMetadata {
     pub length: usize,
 }
diff --git a/src/daft-stats/src/table_stats.rs b/src/daft-stats/src/table_stats.rs
index e0d91d24c6..5f6f32a5a8 100644
--- a/src/daft-stats/src/table_stats.rs
+++ b/src/daft-stats/src/table_stats.rs
@@ -1,6 +1,7 @@
 use std::{
     collections::HashMap,
     fmt::Display,
+    hash::{Hash, Hasher},
     ops::{BitAnd, BitOr, Not},
 };
 
@@ -17,6 +18,15 @@ pub struct TableStatistics {
     pub columns: IndexMap<String, ColumnRangeStatistics>,
 }
 
+impl Hash for TableStatistics {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        for (key, value) in &self.columns {
+            key.hash(state);
+            value.hash(state);
+        }
+    }
+}
+
 impl TableStatistics {
     pub fn from_stats_table(table: &Table) -> DaftResult<Self> {
         // Assumed format is each column having 2 rows:
diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs
index d48cb36d33..359758e802 100644
--- a/src/daft-table/src/lib.rs
+++ b/src/daft-table/src/lib.rs
@@ -6,6 +6,7 @@ use core::slice;
 use std::{
     collections::{HashMap, HashSet},
     fmt::{Display, Formatter, Result},
+    hash::{Hash, Hasher},
 };
 
 use arrow2::array::Array;
@@ -47,6 +48,17 @@ pub struct Table {
     num_rows: usize,
 }
 
+impl Hash for Table {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.schema.hash(state);
+        for col in &self.columns {
+            let hashes = col.hash(None).expect("Failed to hash column");
+            hashes.into_iter().for_each(|h| h.hash(state));
+        }
+        self.num_rows.hash(state);
+    }
+}
+
 #[inline]
 fn _validate_schema(schema: &Schema, columns: &[Series]) -> DaftResult<()> {
     if schema.fields.len() != columns.len() {