Skip to content

Commit

Permalink
Fix recursive-protection feature flag (#13887)
Browse files Browse the repository at this point in the history
* Fix recursive-protection feature flag

* rename feature flag to be consistent

* Make default

* taplo format
  • Loading branch information
alamb authored Dec 24, 2024
1 parent 901a094 commit e99e02b
Show file tree
Hide file tree
Showing 22 changed files with 108 additions and 36 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ Default features:
- `regex_expressions`: regular expression functions, such as `regexp_match`
- `unicode_expressions`: Include unicode aware functions such as `character_length`
- `unparser`: enables support to reverse LogicalPlans back into SQL
- `recursive-protection`: uses [recursive](https://docs.rs/recursive/latest/recursive/) for stack overflow protection.
- `recursive_protection`: uses [recursive](https://docs.rs/recursive/latest/recursive/) for stack overflow protection.

Optional features:

Expand Down
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions datafusion-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ datafusion = { path = "../datafusion/core", version = "43.0.0", features = [
"datetime_expressions",
"encoding_expressions",
"parquet",
"recursive_protection",
"regex_expressions",
"unicode_expressions",
"compression",
Expand Down
3 changes: 1 addition & 2 deletions datafusion/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,11 @@ name = "datafusion_common"
path = "src/lib.rs"

[features]
default = ["recursive-protection"]
avro = ["apache-avro"]
backtrace = []
pyarrow = ["pyo3", "arrow/pyarrow", "parquet"]
force_hash_collisions = []
recursive-protection = ["dep:recursive"]
recursive_protection = ["dep:recursive"]

[dependencies]
ahash = { workspace = true }
Expand Down
14 changes: 7 additions & 7 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ pub trait TreeNode: Sized {
/// TreeNodeVisitor::f_up(ChildNode2)
/// TreeNodeVisitor::f_up(ParentNode)
/// ```
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>(
&'n self,
visitor: &mut V,
Expand Down Expand Up @@ -174,7 +174,7 @@ pub trait TreeNode: Sized {
/// TreeNodeRewriter::f_up(ChildNode2)
/// TreeNodeRewriter::f_up(ParentNode)
/// ```
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn rewrite<R: TreeNodeRewriter<Node = Self>>(
self,
rewriter: &mut R,
Expand All @@ -197,7 +197,7 @@ pub trait TreeNode: Sized {
&'n self,
mut f: F,
) -> Result<TreeNodeRecursion> {
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn apply_impl<'n, N: TreeNode, F: FnMut(&'n N) -> Result<TreeNodeRecursion>>(
node: &'n N,
f: &mut F,
Expand Down Expand Up @@ -232,7 +232,7 @@ pub trait TreeNode: Sized {
self,
mut f: F,
) -> Result<Transformed<Self>> {
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn transform_down_impl<N: TreeNode, F: FnMut(N) -> Result<Transformed<N>>>(
node: N,
f: &mut F,
Expand All @@ -256,7 +256,7 @@ pub trait TreeNode: Sized {
self,
mut f: F,
) -> Result<Transformed<Self>> {
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn transform_up_impl<N: TreeNode, F: FnMut(N) -> Result<Transformed<N>>>(
node: N,
f: &mut F,
Expand Down Expand Up @@ -371,7 +371,7 @@ pub trait TreeNode: Sized {
mut f_down: FD,
mut f_up: FU,
) -> Result<Transformed<Self>> {
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn transform_down_up_impl<
N: TreeNode,
FD: FnMut(N) -> Result<Transformed<N>>,
Expand Down Expand Up @@ -2349,7 +2349,7 @@ pub(crate) mod tests {
Ok(())
}

#[cfg(feature = "recursive-protection")]
#[cfg(feature = "recursive_protection")]
#[test]
fn test_large_tree() {
let mut item = TestTreeNode::new_leaf("initial".to_string());
Expand Down
8 changes: 8 additions & 0 deletions datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ default = [
"unicode_expressions",
"compression",
"parquet",
"recursive_protection",
]
encoding_expressions = ["datafusion-functions/encoding_expressions"]
# Used for testing ONLY: causes all values to hash to the same value (test for collisions)
Expand All @@ -69,6 +70,13 @@ pyarrow = ["datafusion-common/pyarrow", "parquet"]
regex_expressions = [
"datafusion-functions/regex_expressions",
]
recursive_protection = [
"datafusion-common/recursive_protection",
"datafusion-expr/recursive_protection",
"datafusion-optimizer/recursive_protection",
"datafusion-physical-optimizer/recursive_protection",
"datafusion-sql/recursive_protection",
]
serde = ["arrow-schema/serde"]
string_expressions = ["datafusion-functions/string_expressions"]
unicode_expressions = [
Expand Down
3 changes: 1 addition & 2 deletions datafusion/expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ name = "datafusion_expr"
path = "src/lib.rs"

[features]
default = ["recursive-protection"]
recursive-protection = ["dep:recursive"]
recursive_protection = ["dep:recursive"]

[dependencies]
arrow = { workspace = true }
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ impl ExprSchemable for Expr {
/// expression refers to a column that does not exist in the
/// schema, or when the expression is incorrectly typed
/// (e.g. `[utf8] + [bool]`).
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> {
match self {
Expr::Alias(Alias { expr, name, .. }) => match &**expr {
Expand Down
12 changes: 6 additions & 6 deletions datafusion/expr/src/logical_plan/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ impl LogicalPlan {

/// Visits a plan similarly to [`Self::visit`], including subqueries that
/// may appear in expressions such as `IN (SELECT ...)`.
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
pub fn visit_with_subqueries<V: for<'n> TreeNodeVisitor<'n, Node = Self>>(
&self,
visitor: &mut V,
Expand All @@ -687,7 +687,7 @@ impl LogicalPlan {
/// Similarly to [`Self::rewrite`], rewrites this node and its inputs using `f`,
/// including subqueries that may appear in expressions such as `IN (SELECT
/// ...)`.
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
pub fn rewrite_with_subqueries<R: TreeNodeRewriter<Node = Self>>(
self,
rewriter: &mut R,
Expand All @@ -706,7 +706,7 @@ impl LogicalPlan {
&self,
mut f: F,
) -> Result<TreeNodeRecursion> {
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn apply_with_subqueries_impl<
F: FnMut(&LogicalPlan) -> Result<TreeNodeRecursion>,
>(
Expand Down Expand Up @@ -741,7 +741,7 @@ impl LogicalPlan {
self,
mut f: F,
) -> Result<Transformed<Self>> {
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn transform_down_with_subqueries_impl<
F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
>(
Expand All @@ -766,7 +766,7 @@ impl LogicalPlan {
self,
mut f: F,
) -> Result<Transformed<Self>> {
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn transform_up_with_subqueries_impl<
F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
>(
Expand Down Expand Up @@ -794,7 +794,7 @@ impl LogicalPlan {
mut f_down: FD,
mut f_up: FU,
) -> Result<Transformed<Self>> {
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn transform_down_up_with_subqueries_impl<
FD: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
FU: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
Expand Down
3 changes: 1 addition & 2 deletions datafusion/optimizer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ name = "datafusion_optimizer"
path = "src/lib.rs"

[features]
default = ["recursive-protection"]
recursive-protection = ["dep:recursive"]
recursive_protection = ["dep:recursive"]

[dependencies]
arrow = { workspace = true }
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/analyzer/subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ fn check_correlations_in_subquery(inner_plan: &LogicalPlan) -> Result<()> {
}

// Recursively check the unsupported outer references in the sub query plan.
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Result<()> {
if !can_contain_outer_ref && inner_plan.contains_outer_reference() {
return plan_err!("Accessing outer reference columns is not allowed in the plan");
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ impl OptimizerRule for CommonSubexprEliminate {
None
}

#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn rewrite(
&self,
plan: LogicalPlan,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/eliminate_cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl OptimizerRule for EliminateCrossJoin {
true
}

#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn rewrite(
&self,
plan: LogicalPlan,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/optimize_projections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl OptimizerRule for OptimizeProjections {
/// columns.
/// - `Ok(None)`: Signal that the given logical plan did not require any change.
/// - `Err(error)`: An error occurred during the optimization process.
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn optimize_projections(
plan: LogicalPlan,
config: &dyn OptimizerConfig,
Expand Down
3 changes: 1 addition & 2 deletions datafusion/physical-optimizer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ rust-version = { workspace = true }
workspace = true

[features]
default = ["recursive-protection"]
recursive-protection = ["dep:recursive"]
recursive_protection = ["dep:recursive"]

[dependencies]
arrow = { workspace = true }
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-optimizer/src/aggregate_statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl AggregateStatistics {
}

impl PhysicalOptimizerRule for AggregateStatistics {
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ name = "datafusion_sql"
path = "src/lib.rs"

[features]
default = ["unicode_expressions", "unparser", "recursive-protection"]
default = ["unicode_expressions", "unparser"]
unicode_expressions = []
unparser = []
recursive-protection = ["dep:recursive"]
recursive_protection = ["dep:recursive"]

[dependencies]
arrow = { workspace = true }
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {

/// Internal implementation. Use
/// [`Self::sql_expr_to_logical_expr`] to plan exprs.
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn sql_expr_to_logical_expr_internal(
&self,
sql: SQLExpr,
Expand Down
1 change: 1 addition & 0 deletions datafusion/sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ mod query;
mod relation;
mod select;
mod set_expr;
mod stack;
mod statement;
#[cfg(feature = "unparser")]
pub mod unparser;
Expand Down
10 changes: 6 additions & 4 deletions datafusion/sql/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use std::sync::Arc;

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};

use crate::stack::StackGuard;
use datafusion_common::{not_impl_err, Constraints, DFSchema, Result};
use datafusion_expr::expr::Sort;
use datafusion_expr::{
Expand Down Expand Up @@ -62,10 +63,11 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
// The functions called from `set_expr_to_plan()` need more than 128KB
// stack in debug builds as investigated in:
// https://github.com/apache/datafusion/pull/13310#discussion_r1836813902
let min_stack_size = recursive::get_minimum_stack_size();
recursive::set_minimum_stack_size(256 * 1024);
let plan = self.set_expr_to_plan(other, planner_context)?;
recursive::set_minimum_stack_size(min_stack_size);
let plan = {
// scope for dropping _guard
let _guard = StackGuard::new(256 * 1024);
self.set_expr_to_plan(other, planner_context)
}?;
let oby_exprs = to_order_by_exprs(query.order_by)?;
let order_by_rex = self.order_by_to_sort_expr(
oby_exprs,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/set_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use datafusion_expr::{LogicalPlan, LogicalPlanBuilder};
use sqlparser::ast::{SetExpr, SetOperator, SetQuantifier};

impl<S: ContextProvider> SqlToRel<'_, S> {
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
pub(super) fn set_expr_to_plan(
&self,
set_expr: SetExpr,
Expand Down
63 changes: 63 additions & 0 deletions datafusion/sql/src/stack.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

pub use inner::StackGuard;

/// A guard that sets the minimum stack size for the current thread to `min_stack_size` bytes.
#[cfg(feature = "recursive_protection")]
mod inner {
/// Sets the stack size to `min_stack_size` bytes on call to `new()` and
/// resets to the previous value when this structure is dropped.
pub struct StackGuard {
previous_stack_size: usize,
}

impl StackGuard {
/// Sets the stack size to `min_stack_size` bytes on call to `new()` and
/// resets to the previous value when this structure is dropped.
pub fn new(min_stack_size: usize) -> Self {
let previous_stack_size = recursive::get_minimum_stack_size();
recursive::set_minimum_stack_size(min_stack_size);
Self {
previous_stack_size,
}
}
}

impl Drop for StackGuard {
fn drop(&mut self) {
recursive::set_minimum_stack_size(self.previous_stack_size);
}
}
}

/// A stub implementation of the stack guard when the recursive protection
/// feature is not enabled
#[cfg(not(feature = "recursive_protection"))]
mod inner {
/// A stub implementation of the stack guard when the recursive protection
/// feature is not enabled that does nothing
pub struct StackGuard;

impl StackGuard {
/// A stub implementation of the stack guard when the recursive protection
/// feature is not enabled
pub fn new(_min_stack_size: usize) -> Self {
Self
}
}
}

0 comments on commit e99e02b

Please sign in to comment.