From 6e5b52cdefe59787914193d6b45668bf45ec34bb Mon Sep 17 00:00:00 2001 From: Sean Linsley Date: Wed, 11 Jan 2023 09:32:59 -0600 Subject: [PATCH 1/4] Support parsing complex queries with deeply nested ASTs --- Cargo.lock | 49 ++++++++++++++++++-------------------------- Cargo.toml | 4 ++-- build.rs | 4 +++- src/query.rs | 13 +++++++++++- tests/parse_tests.rs | 10 ++++----- 5 files changed, 42 insertions(+), 38 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3688a3a..17204b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -113,12 +113,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" -[[package]] -name = "cc" -version = "1.0.73" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11" - [[package]] name = "cexpr" version = "0.6.0" @@ -178,15 +172,6 @@ dependencies = [ "term", ] -[[package]] -name = "cmake" -version = "0.1.48" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8ad8cef104ac57b68b89df3208164d228503abbdce70f6880ffa3d970e7443a" -dependencies = [ - "cc", -] - [[package]] name = "constant_time_eq" version = "0.1.5" @@ -481,6 +466,16 @@ dependencies = [ "output_vt100", ] +[[package]] +name = "prettyplease" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e97e3215779627f01ee256d2fad52f3d95e8e1c11e9fc6fd08f7cd455d5d5c78" +dependencies = [ + "proc-macro2", + "syn", +] + [[package]] name = "proc-macro2" version = "1.0.40" @@ -492,9 +487,8 @@ dependencies = [ [[package]] name = "prost" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71adf41db68aa0daaefc69bb30bcd68ded9b9abaad5d1fbb6304c4fb390e083e" +version = "0.11.6" +source = "git+https://github.com/pganalyze/prost?branch=recursion-limit-macro#4f02d843d0db6b6aa9df0d52ede33aa611cca2b3" dependencies = [ "bytes", "prost-derive", @@ -502,31 +496,29 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ae5a4388762d5815a9fc0dea33c56b021cdc8dde0c55e0c9ca57197254b0cab" +version = "0.11.6" +source = "git+https://github.com/pganalyze/prost?branch=recursion-limit-macro#4f02d843d0db6b6aa9df0d52ede33aa611cca2b3" dependencies = [ "bytes", - "cfg-if", - "cmake", "heck", "itertools", "lazy_static", "log", "multimap", "petgraph", + "prettyplease", "prost", "prost-types", "regex", + "syn", "tempfile", "which", ] [[package]] name = "prost-derive" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b670f45da57fb8542ebdbb6105a925fe571b67f9e7ed9f47a06a84e72b4e7cc" +version = "0.11.6" +source = "git+https://github.com/pganalyze/prost?branch=recursion-limit-macro#4f02d843d0db6b6aa9df0d52ede33aa611cca2b3" dependencies = [ "anyhow", "itertools", @@ -537,9 +529,8 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d0a014229361011dc8e69c8a1ec6c2e8d0f2af7c91e3ea3f5b2170298461e68" +version = "0.11.6" +source = "git+https://github.com/pganalyze/prost?branch=recursion-limit-macro#4f02d843d0db6b6aa9df0d52ede33aa611cca2b3" dependencies = [ "bytes", "prost", diff --git a/Cargo.toml b/Cargo.toml index c84ea07..73a38ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ repository = "https://github.com/pganalyze/pg_query.rs" [dependencies] itertools = "0.10.3" -prost = "0.10.4" +prost = { git = "https://github.com/pganalyze/prost", branch = "recursion-limit-macro" } serde = { version = "1.0.139", features = ["derive"] } serde_json = "1.0.82" thiserror = "1.0.31" @@ -21,7 +21,7 @@ thiserror = "1.0.31" [build-dependencies] bindgen = "0.60.1" clippy = { version = "0.0.302", optional = true } -prost-build = "0.10.4" +prost-build = { git = "https://github.com/pganalyze/prost", branch = "recursion-limit-macro" } fs_extra = "1.2.0" [dev-dependencies] diff --git a/build.rs b/build.rs index 6213be9..d0f159e 100644 --- a/build.rs +++ b/build.rs @@ -57,7 +57,9 @@ fn main() -> Result<(), Box> { .write_to_file(out_dir.join("bindings.rs"))?; // Generate the protobuf definition - prost_build::compile_protos(&[&out_protobuf_path.join(LIBRARY_NAME).with_extension("proto")], &[&out_protobuf_path])?; + let mut config = prost_build::Config::new(); + config.recursion_limit("ParseResult", 1000); + config.compile_protos(&[&out_protobuf_path.join(LIBRARY_NAME).with_extension("proto")], &[&out_protobuf_path])?; Ok(()) } diff --git a/src/query.rs b/src/query.rs index 25484f5..4e4f4e0 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,5 +1,6 @@ use std::ffi::{CStr, CString}; use std::os::raw::{c_char, c_uint}; +use std::thread::Builder; use prost::Message; @@ -15,6 +16,13 @@ pub struct Fingerprint { pub hex: String, } +// Thread with a larger stack to avoid a stack overflow when decoding protobuf messages. +macro_rules! decode_thread { + ($exp: expr) => { + Builder::new().stack_size(100_000 * 0xFF).spawn(move || $exp).unwrap().join().unwrap() + } +} + /// Parses the given SQL statement into the given abstract syntax tree. /// /// # Example @@ -37,7 +45,10 @@ pub fn parse(statement: &str) -> Result { } else { let data = unsafe { std::slice::from_raw_parts(result.parse_tree.data as *const u8, result.parse_tree.len as usize) }; let stderr = unsafe { CStr::from_ptr(result.stderr_buffer) }.to_string_lossy().to_string(); - protobuf::ParseResult::decode(data).map_err(Error::Decode).and_then(|result| Ok(ParseResult::new(result, stderr))) + match decode_thread!(protobuf::ParseResult::decode(data)) { + Ok(result) => Ok(ParseResult::new(result, stderr)), + Err(error) => Err(Error::Decode(error)), + } }; unsafe { pg_query_free_protobuf_parse_result(result) }; parse_result diff --git a/tests/parse_tests.rs b/tests/parse_tests.rs index 9444bb2..7eb7dd5 100644 --- a/tests/parse_tests.rs +++ b/tests/parse_tests.rs @@ -31,15 +31,15 @@ fn it_handles_errors() { } #[test] -fn it_handles_recursion_error() { +fn it_handles_recursion_without_error_1() { let query = "SELECT a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(b))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))"; - parse(query).err().unwrap(); - // TODO: unsure how to unwrap the private fields on a protobuf decode error - // assert_eq!(error, Error::Decode("recursion limit reached".into())); + let result = parse(query).unwrap(); + assert_eq!(result.tables().len(), 0); + assert_eq!(result.statement_types(), ["SelectStmt"]); } #[test] -fn it_handles_recursion_without_error() { +fn it_handles_recursion_without_error_2() { // The Ruby version of pg_query fails here because of Ruby protobuf limitations let query = r#"SELECT * FROM "t0" JOIN "t1" ON (1) JOIN "t2" ON (1) JOIN "t3" ON (1) JOIN "t4" ON (1) JOIN "t5" ON (1) From e35c4028aa84a9d7caf5b185177732674ae74964 Mon Sep 17 00:00:00 2001 From: Sean Linsley Date: Wed, 11 Jan 2023 09:52:48 -0600 Subject: [PATCH 2/4] cargo fmt --- src/query.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/query.rs b/src/query.rs index 4e4f4e0..864313f 100644 --- a/src/query.rs +++ b/src/query.rs @@ -20,7 +20,7 @@ pub struct Fingerprint { macro_rules! decode_thread { ($exp: expr) => { Builder::new().stack_size(100_000 * 0xFF).spawn(move || $exp).unwrap().join().unwrap() - } + }; } /// Parses the given SQL statement into the given abstract syntax tree. From 26c42e5dc8de4343c8864097700e274b0b3a1f9c Mon Sep 17 00:00:00 2001 From: Sean Linsley Date: Wed, 11 Jan 2023 10:02:58 -0600 Subject: [PATCH 3/4] ensure protobuf compiler is installed --- .github/workflows/main.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 2ead9a4..c170488 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -21,6 +21,9 @@ jobs: - nightly steps: + - name: Install protobuf-compiler + run: sudo apt-get install -y protobuf-compiler + - uses: actions/checkout@v2 with: submodules: recursive @@ -54,6 +57,9 @@ jobs: name: Check file formatting and style runs-on: ubuntu-latest steps: + - name: Install protobuf-compiler + run: sudo apt-get install -y protobuf-compiler + - uses: actions/checkout@v2 with: submodules: recursive From 2f09c574d8e27c1f43ba7269dfb405bc66d13b9a Mon Sep 17 00:00:00 2001 From: Sean Linsley Date: Fri, 13 Jan 2023 14:05:23 -0600 Subject: [PATCH 4/4] let downstream code decide how to handle stack overflows --- README.md | 10 ++++++++++ src/query.rs | 10 +--------- tests/parse_tests.rs | 5 ++++- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index ba60dcd..25ff2fc 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,16 @@ let result = pg_query::parse(query).unwrap(); assert_eq!(result.truncate(32).unwrap(), "INSERT INTO x (...) VALUES (...)"); ``` +## Caveats + +When parsing very complex queries you may run into a stack overflow. This can be worked around by using a thread with a custom stack size ([stdlib](https://doc.rust-lang.org/std/thread/index.html#stack-size), [tokio](https://docs.rs/tokio/latest/tokio/runtime/struct.Builder.html#method.thread_stack_size)), or using the stacker crate to resize the main thread's stack: + +```rust +stacker::grow(20 * 1024 * 1024, || pg_query::parse(query)) +``` + +However, a sufficiently complex query could still run into a stack overflow after you increase the stack size. With some work it may be possible to add an adapter API to the prost crate in order to dynamically increase the stack size as needed like [serde_stacker](https://crates.io/crates/serde_stacker) does (if anyone wants to take that on). + ## Credits Thanks to [Paul Mason](https://github.com/paupino) for his work on [pg_parse](https://github.com/paupino/pg_parse) that this crate is based on. diff --git a/src/query.rs b/src/query.rs index 864313f..304710b 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,6 +1,5 @@ use std::ffi::{CStr, CString}; use std::os::raw::{c_char, c_uint}; -use std::thread::Builder; use prost::Message; @@ -16,13 +15,6 @@ pub struct Fingerprint { pub hex: String, } -// Thread with a larger stack to avoid a stack overflow when decoding protobuf messages. -macro_rules! decode_thread { - ($exp: expr) => { - Builder::new().stack_size(100_000 * 0xFF).spawn(move || $exp).unwrap().join().unwrap() - }; -} - /// Parses the given SQL statement into the given abstract syntax tree. /// /// # Example @@ -45,7 +37,7 @@ pub fn parse(statement: &str) -> Result { } else { let data = unsafe { std::slice::from_raw_parts(result.parse_tree.data as *const u8, result.parse_tree.len as usize) }; let stderr = unsafe { CStr::from_ptr(result.stderr_buffer) }.to_string_lossy().to_string(); - match decode_thread!(protobuf::ParseResult::decode(data)) { + match protobuf::ParseResult::decode(data) { Ok(result) => Ok(ParseResult::new(result, stderr)), Err(error) => Err(Error::Decode(error)), } diff --git a/tests/parse_tests.rs b/tests/parse_tests.rs index 7eb7dd5..320bb0e 100644 --- a/tests/parse_tests.rs +++ b/tests/parse_tests.rs @@ -4,6 +4,9 @@ #[cfg(test)] use itertools::sorted; +#[cfg(test)] +use std::thread::Builder; + use pg_query::{ parse, protobuf::{self, a_const::Val}, @@ -33,7 +36,7 @@ fn it_handles_errors() { #[test] fn it_handles_recursion_without_error_1() { let query = "SELECT a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(b))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))"; - let result = parse(query).unwrap(); + let result = Builder::new().stack_size(20 * 1024 * 1024).spawn(move || parse(query)).unwrap().join().unwrap().unwrap(); assert_eq!(result.tables().len(), 0); assert_eq!(result.statement_types(), ["SelectStmt"]); }