Skip to content

Commit

Permalink
Add #[recursive]
Browse files Browse the repository at this point in the history
  • Loading branch information
blaginin committed Nov 14, 2024
1 parent 2bb8144 commit 458f748
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 1 deletion.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ path = "src/lib.rs"

[features]
default = ["std"]
std = []
std = ["recursive"]
# Enable JSON output in the `cli` example:
json_example = ["serde_json", "serde"]
visitor = ["sqlparser_derive"]

[dependencies]
bigdecimal = { version = "0.4.1", features = ["serde"], optional = true }
log = "0.4"
recursive = { version = "0.1.1", optional = true}

serde = { version = "1.0", features = ["derive"], optional = true }
# serde_json is only used in examples/cli, but we have to put it outside
# of dev-dependencies because of
Expand Down
1 change: 1 addition & 0 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) -> proc_
let expanded = quote! {
// The generated impl.
impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause {
#[cfg_attr(feature = "std", recursive::recursive)]
fn visit<V: sqlparser::ast::#visitor_trait>(
&#modifier self,
visitor: &mut V
Expand Down
1 change: 1 addition & 0 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1188,6 +1188,7 @@ impl fmt::Display for CastFormat {
}

impl fmt::Display for Expr {
#[cfg_attr(feature = "std", recursive::recursive)]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Expr::Identifier(s) => write!(f, "{s}"),
Expand Down
25 changes: 25 additions & 0 deletions src/ast/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -884,4 +884,29 @@ mod tests {
assert_eq!(actual, expected)
}
}

struct QuickVisitor;

impl Visitor for QuickVisitor {
type Break = ();
}

#[test]
fn overflow() {
let cond = (0..1000)
.map(|n| format!("X = {}", n))
.collect::<Vec<_>>()
.join(" OR ");
let sql = format!("SELECT x where {0}", cond);

let dialect = GenericDialect {};
let tokens = Tokenizer::new(&dialect, sql.as_str()).tokenize().unwrap();
let s = Parser::new(&dialect)
.with_tokens(tokens)
.parse_statement()
.unwrap();

let mut visitor = QuickVisitor {} ;
s.visit(&mut visitor);
}
}
11 changes: 11 additions & 0 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use sqlparser::dialect::{
use sqlparser::keywords::ALL_KEYWORDS;
use sqlparser::parser::{Parser, ParserError, ParserOptions};
use sqlparser::tokenizer::Tokenizer;
use std::iter;
use test_utils::{
all_dialects, all_dialects_where, alter_table_op, assert_eq_vec, call, expr_from_projection,
join, number, only, table, table_alias, TestedDialects,
Expand Down Expand Up @@ -11748,3 +11749,13 @@ fn parse_create_table_select() {
);
}
}

#[test]
fn overflow() {
let expr = iter::repeat("1").take(1000).collect::<Vec<_>>().join(" + ");
let sql = format!("SELECT {}", expr);

let mut statements = Parser::parse_sql(&GenericDialect {}, sql.as_str()).unwrap();
let statement = statements.pop().unwrap();
assert_eq!(statement.to_string(), sql);
}

0 comments on commit 458f748

Please sign in to comment.