diff --git a/Cargo.lock b/Cargo.lock
index a2a868b..610be31 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -389,6 +389,17 @@ version = "1.6.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
 
+[[package]]
+name = "bigdecimal"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a6773ddc0eafc0e509fb60e48dff7f450f8e674a0686ae8605e8d9901bd5eefa"
+dependencies = [
+ "num-bigint",
+ "num-integer",
+ "num-traits",
+]
+
 [[package]]
 name = "bitflags"
 version = "1.3.2"
@@ -451,11 +462,35 @@ checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38"
 dependencies = [
  "android-tzdata",
  "iana-time-zone",
+ "js-sys",
  "num-traits",
  "serde",
+ "wasm-bindgen",
  "windows-targets",
 ]
 
+[[package]]
+name = "chrono-tz"
+version = "0.8.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f1369bc6b9e9a7dfdae2055f6ec151fe9c554a9d23d357c0237cee2e25eaabb7"
+dependencies = [
+ "chrono",
+ "chrono-tz-build",
+ "phf",
+]
+
+[[package]]
+name = "chrono-tz-build"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e2f5ebdc942f57ed96d560a6d1a459bae5851102a25d5bf89dc04ae453e31ecf"
+dependencies = [
+ "parse-zoneinfo",
+ "phf",
+ "phf_codegen",
+]
+
 [[package]]
 name = "clap"
 version = "4.4.6"
@@ -689,6 +724,19 @@ dependencies = [
  "syn 1.0.109",
 ]
 
+[[package]]
+name = "derive_more"
+version = "0.99.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321"
+dependencies = [
+ "convert_case",
+ "proc-macro2",
+ "quote",
+ "rustc_version",
+ "syn 1.0.109",
+]
+
 [[package]]
 name = "digest"
 version = "0.10.7"
@@ -1115,6 +1163,15 @@ version = "2.8.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6"
 
+[[package]]
+name = "iterable"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c151dfd6ab7dff5ca5567d82041bb286f07469ece85c1e2444a6d26d7057a65f"
+dependencies = [
+ "itertools",
+]
+
 [[package]]
 name = "itertools"
 version = "0.10.5"
@@ -1148,6 +1205,7 @@ dependencies = [
  "once_cell",
  "peg",
  "phf",
+ "prusto",
  "regex",
  "rusqlite",
  "serde",
@@ -1573,6 +1631,15 @@ dependencies = [
  "windows-targets",
 ]
 
+[[package]]
+name = "parse-zoneinfo"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c705f256449c60da65e11ff6626e0c16a0a0b96aaa348de61376b249bc340f41"
+dependencies = [
+ "regex",
+]
+
 [[package]]
 name = "peg"
 version = "0.8.1"
@@ -1634,6 +1701,16 @@ dependencies = [
  "phf_shared",
 ]
 
+[[package]]
+name = "phf_codegen"
+version = "0.11.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a"
+dependencies = [
+ "phf_generator",
+ "phf_shared",
+]
+
 [[package]]
 name = "phf_generator"
 version = "0.11.2"
@@ -1726,6 +1803,41 @@ dependencies = [
  "unicode-ident",
 ]
 
+[[package]]
+name = "prusto"
+version = "0.5.0"
+source = "git+https://github.com/nooberfsh/prusto.git#adf326f58b0c594b9b9e52b12d3bb03cd907980d"
+dependencies = [
+ "bigdecimal",
+ "chrono",
+ "chrono-tz",
+ "derive_more",
+ "futures",
+ "http",
+ "iterable",
+ "lazy_static",
+ "log",
+ "prusto-macros",
+ "regex",
+ "reqwest",
+ "serde",
+ "serde_json",
+ "thiserror",
+ "tokio",
"urlencoding", + "uuid", +] + +[[package]] +name = "prusto-macros" +version = "0.2.0" +source = "git+https://github.com/nooberfsh/prusto.git#adf326f58b0c594b9b9e52b12d3bb03cd907980d" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "quick-xml" version = "0.28.2" @@ -2145,6 +2257,15 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + [[package]] name = "signature" version = "2.1.0" @@ -2466,7 +2587,9 @@ dependencies = [ "libc", "mio", "num_cpus", + "parking_lot 0.12.1", "pin-project-lite", + "signal-hook-registry", "socket2 0.5.4", "tokio-macros", "windows-sys", @@ -2640,6 +2763,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8parse" version = "0.2.1" @@ -2653,6 +2782,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d" dependencies = [ "getrandom", + "serde", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 866162d..0eccb62 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,8 @@ joinery_macros = { path = "joinery_macros" } once_cell = "1.18.0" peg = "0.8.1" phf = { version = "0.11.2", features = ["macros"] } +# Waiting on https://github.com/nooberfsh/prusto/issues/33 +prusto = { git = "https://github.com/nooberfsh/prusto.git" } regex = "1.10.0" rusqlite = { version = "0.29.0", features = ["bundled", "functions", "vtab"] } serde = { version = "1.0.188", features = ["derive"] } diff --git a/joinery_macros/src/lib.rs b/joinery_macros/src/lib.rs index 49778e1..9e5fb55 100644 --- a/joinery_macros/src/lib.rs +++ b/joinery_macros/src/lib.rs @@ -14,10 +14,11 @@ pub fn emit_macro_derive(input: TokenStream) -> TokenStream { fn impl_emit_macro(ast: &syn::DeriveInput) -> TokenStream2 { let name = &ast.ident; + let (impl_generics, ty_generics, where_clause) = &ast.generics.split_for_impl(); quote! { - impl Emit for #name { + impl #impl_generics Emit for #name #ty_generics #where_clause { fn emit(&self, t: Target, f: &mut fmt::Formatter<'_>) -> fmt::Result { - <#name as EmitDefault>::emit_default(self, t, f) + <#name #ty_generics as EmitDefault>::emit_default(self, t, f) } } } @@ -35,9 +36,10 @@ pub fn emit_default_macro_derive(input: TokenStream) -> TokenStream { /// TODO: If we see `#[emit(skip)]` on a field, we should skip it. fn impl_emit_default_macro(ast: &syn::DeriveInput) -> TokenStream2 { let name = &ast.ident; + let (impl_generics, ty_generics, where_clause) = &ast.generics.split_for_impl(); let implementation = emit_default_body(name, &ast.data); quote! { - impl EmitDefault for #name { + impl #impl_generics EmitDefault for #name #ty_generics #where_clause { fn emit_default(&self, t: Target, f: &mut fmt::Formatter<'_>) -> fmt::Result { #implementation } diff --git a/src/analyze.rs b/src/analyze.rs index 76cae52..050cfd5 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -37,7 +37,7 @@ impl FunctionCallCounts { name.push('*'); } else { // Push '_' separated by ','. 
diff --git a/src/analyze.rs b/src/analyze.rs
index 76cae52..050cfd5 100644
--- a/src/analyze.rs
+++ b/src/analyze.rs
@@ -37,7 +37,7 @@ impl FunctionCallCounts {
                 name.push('*');
             } else {
                 // Push '_' separated by ','.
-                for i in 0..function_call.args.nodes.len() {
+                for (i, _) in function_call.args.node_iter().enumerate() {
                     if i > 0 {
                         name.push(',');
                     }
@@ -62,7 +62,7 @@ impl FunctionCallCounts {
                     .unescaped_bigquery()
                     .to_ascii_uppercase(),
             );
-            for i in 0..special_date_function_call.args.nodes.len() {
+            for (i, _) in special_date_function_call.args.node_iter().enumerate() {
                 if i > 0 {
                     name.push(',');
                 }
diff --git a/src/ast.rs b/src/ast.rs
index 6bd24f7..6ac9c0c 100644
--- a/src/ast.rs
+++ b/src/ast.rs
@@ -20,6 +20,7 @@ use std::{
     borrow::Cow,
     fmt::{self, Display as _},
+    mem::take,
     ops::Range,
 };
 
@@ -83,6 +84,11 @@ static KEYWORDS: phf::Set<&'static str> = phf::phf_set! {
 /// to construct from within our parser.
 type Span = Range<usize>;
 
+/// Used to represent a missing span.
+pub fn span_none() -> Span {
+    usize::MAX..usize::MAX
+}
+
 /// A function that compares two strings for equality.
 type StrCmp = dyn Fn(&str, &str) -> bool;
 
@@ -93,6 +99,7 @@ pub enum Target {
     BigQuery,
     Snowflake,
     SQLite3,
+    Trino,
 }
 
 impl Target {
@@ -108,6 +115,7 @@ impl fmt::Display for Target {
             Target::BigQuery => write!(f, "bigquery"),
             Target::Snowflake => write!(f, "snowflake"),
             Target::SQLite3 => write!(f, "sqlite3"),
+            Target::Trino => write!(f, "trino"),
         }
     }
 }
@@ -292,7 +300,7 @@ impl Emit for Token {
     fn emit(&self, t: Target, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match t {
             Target::BigQuery => write!(f, "{}", self.text),
-            Target::Snowflake | Target::SQLite3 => {
+            Target::Snowflake | Target::SQLite3 | Target::Trino => {
                 // Write out the token itself.
                 write!(f, "{}", self.token_str())?;
 
@@ -317,6 +325,18 @@ impl Emit for Token {
     }
 }
 
+/// A node type, for use with [`NodeVec`].
+pub trait Node: Clone + fmt::Debug + Drive + DriveMut + Emit + 'static {}
+
+impl<T: Clone + fmt::Debug + Drive + DriveMut + Emit + 'static> Node for T {}
+
+/// Either a node or a separator token.
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+pub enum NodeOrSep<T: Node> {
+    Node(T),
+    Sep(Token),
+}
+
 /// A vector which contains a list of nodes, along with any whitespace that
 /// appears after a node but before any separating punctuation. This can be
 /// used either with or without a final trailing separator.
@@ -330,72 +350,140 @@ impl Emit for Token {
 /// nodes and separators, and define custom node-only and separator-only
 /// iterators.
 #[derive(Debug)]
-pub struct NodeVec<T: fmt::Debug> {
-    pub nodes: Vec<T>,
-    pub separators: Vec<Token>,
+pub struct NodeVec<T: Node> {
+    /// The separator to use when adding items.
+    pub separator: &'static str,
+    /// The nodes and separators in this vector.
+    pub items: Vec<NodeOrSep<T>>,
 }
 
-impl<T: fmt::Debug> NodeVec<T> {
-    /// Iterate over the nodes in this [`NodeVec`].
-    pub fn iter(&self) -> impl Iterator<Item = &T> {
-        self.nodes.iter()
+impl<T: Node> NodeVec<T> {
+    /// Create a new [`NodeVec`] with the given separator.
+    pub fn new(separator: &'static str) -> NodeVec<T> {
+        NodeVec {
+            separator,
+            items: vec![],
+        }
     }
-}
 
-impl<T: Clone + fmt::Debug> Clone for NodeVec<T> {
-    fn clone(&self) -> Self {
+    /// Take the elements from this [`NodeVec`], leaving it empty, and return a new
+    /// [`NodeVec`] containing the taken elements.
+    pub fn take(&mut self) -> NodeVec<T> {
         NodeVec {
-            nodes: self.nodes.clone(),
-            separators: self.separators.clone(),
+            separator: self.separator,
+            items: take(&mut self.items),
+        }
+    }
+
+    /// Add a node to this [`NodeVec`].
+    pub fn push(&mut self, node: T) {
+        if let Some(NodeOrSep::Node(_)) = self.items.last() {
+            self.items.push(NodeOrSep::Sep(Token {
+                span: span_none(),
+                ws_offset: self.separator.len(),
+                text: self.separator.to_owned(),
+            }));
+        }
+        self.items.push(NodeOrSep::Node(node));
+    }
+
+    /// Add a node or a separator to this [`NodeVec`], inserting or removing
+    /// separators as needed to ensure that the vector is well-formed.
+    pub fn push_node_or_sep(&mut self, node_or_sep: NodeOrSep<T>) {
+        match node_or_sep {
+            NodeOrSep::Node(_) => {
+                if let Some(NodeOrSep::Node(_)) = self.items.last() {
+                    self.items.push(NodeOrSep::Sep(Token {
+                        span: span_none(),
+                        ws_offset: self.separator.len(),
+                        text: self.separator.to_owned(),
+                    }));
+                }
+            }
+            NodeOrSep::Sep(_) => {
+                if let Some(NodeOrSep::Sep(_)) = self.items.last() {
+                    self.items.pop();
+                }
+            }
         }
+        self.items.push(node_or_sep);
+    }
+
+    /// Iterate over the node and separators in this [`NodeVec`].
+    pub fn iter(&self) -> impl Iterator<Item = &NodeOrSep<T>> {
+        self.items.iter()
+    }
+
+    /// Iterate over just the nodes in this [`NodeVec`].
+    pub fn node_iter(&self) -> impl Iterator<Item = &T> {
+        self.items.iter().filter_map(|item| match item {
+            NodeOrSep::Node(node) => Some(node),
+            NodeOrSep::Sep(_) => None,
+        })
+    }
+
+    /// Iterate over nodes and separators separately. Used internally for
+    /// parsing dotted names.
+    fn into_node_and_sep_iters(self) -> (impl Iterator<Item = T>, impl Iterator<Item = Token>) {
+        let mut nodes = vec![];
+        let mut seps = vec![];
+        for item in self.items {
+            match item {
+                NodeOrSep::Node(node) => nodes.push(node),
+                NodeOrSep::Sep(token) => seps.push(token),
+            }
+        }
+        (nodes.into_iter(), seps.into_iter())
     }
 }
 
-impl<T: fmt::Debug> Default for NodeVec<T> {
-    fn default() -> Self {
+impl<T: Node> Clone for NodeVec<T> {
+    fn clone(&self) -> Self {
         NodeVec {
-            nodes: Vec::new(),
-            separators: Vec::new(),
+            separator: self.separator,
+            items: self.items.clone(),
         }
     }
 }
 
-impl<'a, T: fmt::Debug> IntoIterator for &'a NodeVec<T> {
-    type Item = &'a T;
-    type IntoIter = std::slice::Iter<'a, T>;
+impl<T: Node> IntoIterator for NodeVec<T> {
+    type Item = NodeOrSep<T>;
+    type IntoIter = std::vec::IntoIter<NodeOrSep<T>>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.items.into_iter()
+    }
+}
+
+impl<'a, T: Node> IntoIterator for &'a NodeVec<T> {
+    type Item = &'a NodeOrSep<T>;
+    type IntoIter = std::slice::Iter<'a, NodeOrSep<T>>;
 
     fn into_iter(self) -> Self::IntoIter {
-        self.nodes.iter()
+        self.items.iter()
     }
 }
 
 // Mutable iterator for `DriveMut`.
-impl<'a, T: fmt::Debug> IntoIterator for &'a mut NodeVec<T> {
-    type Item = &'a mut T;
-    type IntoIter = std::slice::IterMut<'a, T>;
+impl<'a, T: Node> IntoIterator for &'a mut NodeVec<T> {
+    type Item = &'a mut NodeOrSep<T>;
+    type IntoIter = std::slice::IterMut<'a, NodeOrSep<T>>;
 
     fn into_iter(self) -> Self::IntoIter {
-        self.nodes.iter_mut()
+        self.items.iter_mut()
     }
 }
 
-impl<T: fmt::Debug + Emit> Emit for NodeVec<T> {
+impl<T: Node> Emit for NodeVec<T> {
     fn emit(&self, t: Target, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        for (i, node) in self.nodes.iter().enumerate() {
-            write!(f, "{}", t.f(node))?;
-            if i < self.separators.len() {
-                let sep = &self.separators[i];
-                if i + 1 < self.nodes.len() {
-                    write!(f, "{}", t.f(sep))?;
-                } else {
-                    // Trailing separator.
-                    match t {
-                        Target::BigQuery => write!(f, "{}", t.f(sep))?,
-                        Target::Snowflake | Target::SQLite3 => {
-                            write!(f, "{}", t.f(&sep.with_erased_token_str()))?
-                        }
-                    }
-                }
-            }
+        for (i, node_or_sep) in self.items.iter().enumerate() {
+            let is_last = i + 1 == self.items.len();
+            match node_or_sep {
+                NodeOrSep::Node(node) => node.emit(t, f)?,
+                NodeOrSep::Sep(sep) if is_last && t != Target::BigQuery => {
+                    sep.with_erased_token_str().emit(t, f)?
+                }
+                NodeOrSep::Sep(sep) => sep.emit(t, f)?,
            }
         }
         Ok(())
@@ -455,7 +543,7 @@ impl Emit for Identifier {
             // Snowflake and SQLite3 use double quoted identifiers and
             // escape quotes by doubling them. Neither allows backslash
             // escapes here, though Snowflake does in strings.
-            Target::Snowflake | Target::SQLite3 => write!(
+            Target::Snowflake | Target::SQLite3 | Target::Trino => write!(
                 f,
                 "{}{}",
                 SQLite3Ident(&self.text),
@@ -551,6 +639,7 @@ pub enum Statement {
     InsertInto(InsertIntoStatement),
     CreateTable(CreateTableStatement),
     CreateView(CreateViewStatement),
+    DropTable(DropTableStatement),
     DropView(DropViewStatement),
 }
 
@@ -1461,6 +1550,30 @@ impl Emit for DataType {
                     )
                 }
             },
+            Target::Trino => match self {
+                DataType::Bool(token) => token.with_token_str("BOOLEAN").emit(t, f),
+                DataType::Bytes(token) => token.with_token_str("VARBINARY").emit(t, f),
+                DataType::Date(token) => token.emit(t, f),
+                DataType::Datetime(token) => token.with_token_str("TIMESTAMP").emit(t, f),
+                DataType::Float64(token) => token.with_token_str("DOUBLE").emit(t, f),
+                DataType::Geography(token) => token.with_token_str("JSON").emit(t, f),
+                DataType::Int64(token) => token.with_token_str("BIGINT").emit(t, f),
+                // TODO: This cannot be done safely in Trino, because you always
+                // need to specify the precision and where to put the decimal
+                // place.
+                DataType::Numeric(token) => token.with_token_str("DECIMAL(?,?)").emit(t, f),
+                DataType::String(token) => token.with_token_str("VARCHAR").emit(t, f),
+                DataType::Time(token) => token.emit(t, f),
+                DataType::Timestamp(token) => {
+                    token.with_token_str("TIMESTAMP WITH TIME ZONE").emit(t, f)
+                }
+                // TODO: Can we declare the element type?
+                DataType::Array { array_token, .. } => array_token.ensure_ws().emit(t, f),
+                // TODO: I think we can translate the column types?
+                DataType::Struct { struct_token, .. } => {
+                    struct_token.with_token_str("ROW").emit(t, f)
+                }
+            },
             _ => self.emit_default(t, f),
         }
     }
@@ -1730,7 +1843,7 @@ pub struct CreateTableStatement {
 impl Emit for CreateTableStatement {
     fn emit(&self, t: Target, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match t {
-            Target::SQLite3 if self.or_replace.is_some() => {
+            Target::SQLite3 | Target::Trino if self.or_replace.is_some() => {
                 // We need to convert this to a `DROP TABLE IF EXISTS` statement.
                 write!(f, "DROP TABLE IF EXISTS {};", t.f(&self.table_name))?;
             }
@@ -1810,14 +1923,31 @@ pub struct ColumnDefinition {
     pub data_type: DataType,
 }
 
+/// A `DROP TABLE` statement.
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+pub struct DropTableStatement {
+    pub drop_token: Token,
+    pub table_token: Token,
+    pub if_exists: Option<IfExists>,
+    pub table_name: TableName,
+}
+
 /// A `DROP VIEW` statement.
 #[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
 pub struct DropViewStatement {
     pub drop_token: Token,
     pub view_token: Token,
+    pub if_exists: Option<IfExists>,
     pub view_name: TableName,
 }
 
+/// An `IF EXISTS` modifier.
+#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault)]
+pub struct IfExists {
+    pub if_token: Token,
+    pub exists_token: Token,
+}
+
 /// Parse BigQuery SQL.
 pub fn parse_sql(filename: &str, sql: &str) -> Result<SqlProgram> {
     // Parse with or without tracing, as appropriate. The tracing code throws
@@ -1890,6 +2020,7 @@ peg::parser!
{
         / d:delete_from_statement() { Statement::DeleteFrom(d) }
         / c:create_table_statement() { Statement::CreateTable(c) }
         / c:create_view_statement() { Statement::CreateView(c) }
+        / d:drop_table_statement() { Statement::DropTable(d) }
         / d:drop_view_statement() { Statement::DropView(d) }
 
         rule query_statement() -> QueryStatement
@@ -2369,7 +2500,7 @@ peg::parser! {
                 FunctionCall {
                     name,
                     paren1,
-                    args: args.unwrap_or_default(),
+                    args: args.unwrap_or_else(|| NodeVec::new(",")),
                     paren2,
                     over_clause,
                 }
@@ -2724,33 +2855,51 @@ peg::parser! {
             }
         }
 
+        rule drop_table_statement() -> DropTableStatement
+            = drop_token:k("DROP") table_token:k("TABLE") if_exists:if_exists()? table_name:table_name() {
+                DropTableStatement {
+                    drop_token,
+                    table_token,
+                    if_exists,
+                    table_name,
+                }
+            }
+
         rule drop_view_statement() -> DropViewStatement
             // Oddly, BigQuery accepts `DELETE VIEW`.
-            = drop_token:(k("DROP") / k("DELETE")) view_token:k("VIEW") view_name:table_name() {
+            = drop_token:(k("DROP") / k("DELETE")) view_token:k("VIEW") if_exists:if_exists()? view_name:table_name() {
                 DropViewStatement {
                     // Fix this at parse time. Nobody wants `DELETE VIEW`.
                     drop_token: drop_token.with_token_str("DROP"),
                     view_token,
+                    if_exists,
                     view_name,
                 }
             }
 
+        rule if_exists() -> IfExists
+            = if_token:k("IF") exists_token:k("EXISTS") {
+                IfExists {
+                    if_token,
+                    exists_token,
+                }
+            }
+
         /// A table name, such as `t1` or `project-123.dataset1.table2`.
         rule table_name() -> TableName
             // We handle this manually because of PEG backtracking limitations.
             = dotted:dotted_name() {?
-                let len = dotted.nodes.len();
-                let mut nodes = dotted.nodes.into_iter();
-                let mut dots = dotted.separators.into_iter();
+                let len = dotted.items.len();
+                let (mut nodes, mut dots) = dotted.into_node_and_sep_iters();
                 if len == 1 {
                     Ok(TableName::Table { table: nodes.next().unwrap() })
-                } else if len == 2 {
+                } else if len == 3 {
                     Ok(TableName::DatasetTable {
                         dataset: nodes.next().unwrap(),
                         dot: dots.next().unwrap(),
                         table: nodes.next().unwrap(),
                     })
-                } else if len == 3 {
+                } else if len == 5 {
                     Ok(TableName::ProjectDatasetTable {
                         project: nodes.next().unwrap(),
                         dot1: dots.next().unwrap(),
@@ -2767,16 +2916,15 @@ peg::parser! {
         rule table_and_column_name() -> TableAndColumnName
             // We handle this manually because of PEG backtracking limitations.
             = dotted:dotted_name() {?
-                let len = dotted.nodes.len();
-                let mut nodes = dotted.nodes.into_iter();
-                let mut dots = dotted.separators.into_iter();
-                if len == 2 {
+                let len = dotted.items.len();
+                let (mut nodes, mut dots) = dotted.into_node_and_sep_iters();
+                if len == 3 {
                     Ok(TableAndColumnName {
                         table_name: TableName::Table { table: nodes.next().unwrap() },
                         dot: dots.next().unwrap(),
                         column_name: nodes.next().unwrap(),
                     })
-                } else if len == 3 {
+                } else if len == 5 {
                     Ok(TableAndColumnName {
                         table_name: TableName::DatasetTable {
                             dataset: nodes.next().unwrap(),
@@ -2786,7 +2934,7 @@ peg::parser! {
                         dot: dots.next().unwrap(),
                         column_name: nodes.next().unwrap(),
                     })
-                } else if len == 4 {
+                } else if len == 7 {
                     Ok(TableAndColumnName {
                         table_name: TableName::ProjectDatasetTable {
                             project: nodes.next().unwrap(),
@@ -2806,20 +2954,19 @@ peg::parser! {
         rule dotted_function_name() -> FunctionName
             // We handle this manually because of PEG backtracking limitations.
             = dotted:dotted_name() {?
-                let len = dotted.nodes.len();
-                let mut nodes = dotted.nodes.into_iter();
-                let mut dots = dotted.separators.into_iter();
+                let len = dotted.items.len();
+                let (mut nodes, mut dots) = dotted.into_node_and_sep_iters();
                 if len == 1 {
                     Ok(FunctionName::Function { function: nodes.next().unwrap() })
-                } else if len == 2 {
+                } else if len == 3 {
                     Ok(FunctionName::DatasetFunction {
                         dataset: nodes.next().unwrap(),
                         dot: dots.next().unwrap(),
                         function: nodes.next().unwrap(),
                     })
-                } else if len == 3 {
+                } else if len == 5 {
                     Ok(FunctionName::ProjectDatasetFunction {
                         project: nodes.next().unwrap(),
                         dot1: dots.next().unwrap(),
@@ -2919,24 +3066,23 @@ peg::parser! {
         rule hex_digit() = ['0'..='9' | 'a'..='f' | 'A'..='F']
 
         /// Punctuation separated list. Does not allow a trailing separator.
-        rule sep<T: fmt::Debug>(item: rule<T>, separator: &'static str) -> NodeVec<T>
-            = first:item() items:(sep:t(separator) item:item() { (sep, item) })* {
-                let mut nodes = Vec::new();
-                let mut separators = Vec::new();
-                nodes.push(first);
-                for (sep, item) in items {
-                    separators.push(sep);
-                    nodes.push(item);
+        rule sep<T: Node>(node: rule<T>, separator: &'static str) -> NodeVec<T>
+            = first:node() rest:(sep:t(separator) node:node() { (sep, node) })* {
+                let mut items = Vec::new();
+                items.push(NodeOrSep::Node(first));
+                for (sep, node) in rest {
+                    items.push(NodeOrSep::Sep(sep));
+                    items.push(NodeOrSep::Node(node));
                 }
-                NodeVec { nodes, separators }
+                NodeVec { separator, items }
             }
 
         /// Punctuation separated list. Allows a trailing separator.
-        rule sep_opt_trailing<T: fmt::Debug>(item: rule<T>, separator: &'static str) -> NodeVec<T>
+        rule sep_opt_trailing<T: Node>(item: rule<T>, separator: &'static str) -> NodeVec<T>
             = list:sep(item, separator) trailing_sep:t(separator)? {
                 let mut list = list;
                 if let Some(trailing_sep) = trailing_sep {
-                    list.separators.push(trailing_sep);
+                    list.items.push(NodeOrSep::Sep(trailing_sep));
                 }
                 list
             }
@@ -3166,7 +3312,7 @@ mod tests {
             (r#"DELETE FROM t AS t2 WHERE a = 0"#, None),
             (r#"CREATE OR REPLACE TABLE t2 (a INT64, b INT64)"#, None),
             (r#"CREATE OR REPLACE TABLE t2 AS (SELECT * FROM t)"#, None),
-            //(r#"DROP TABLE t2"#, None),
+            (r#"DROP TABLE t2"#, None),
             (r#"CREATE OR REPLACE VIEW v AS (SELECT * FROM t)"#, None),
             (r#"DROP VIEW v"#, None),
             (
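
With separators stored inline, a dotted name like `a.b` is now `[Node(a), Sep(.), Node(b)]`, which is why the length checks in `table_name`, `table_and_column_name`, and `dotted_function_name` changed from 1/2/3 to 1/3/5. A minimal sketch of the `push` invariant (hypothetical test; `ident` is an assumed helper that builds an `Identifier`):

```rust
#[test]
fn push_keeps_nodes_and_separators_alternating() {
    let mut names = NodeVec::<Identifier>::new(",");
    names.push(ident("a"));
    names.push(ident("b")); // A "," separator token is synthesized in between.
    assert_eq!(names.items.len(), 3); // [Node("a"), Sep(","), Node("b")]

    // `push_node_or_sep` keeps the list well-formed: a node following a node
    // gets a separator inserted, a separator following a separator replaces it.
    names.push_node_or_sep(NodeOrSep::Node(ident("c")));
    assert_eq!(names.node_iter().count(), 3);
}
```
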
diff --git a/src/cmd/sql_test.rs b/src/cmd/sql_test.rs
index cd3446d..c52806b 100644
--- a/src/cmd/sql_test.rs
+++ b/src/cmd/sql_test.rs
@@ -212,7 +212,7 @@ struct OutputTablePair {
 fn find_output_tables(ast: &ast::SqlProgram) -> Result<Vec<OutputTablePair>> {
     let mut tables = Vec::>>::default();
 
-    for s in &ast.statements {
+    for s in ast.statements.node_iter() {
         let name = match s {
             ast::Statement::CreateTable(CreateTableStatement { table_name, .. }) => table_name,
             ast::Statement::CreateView(CreateViewStatement { view_name, .. }) => view_name,
diff --git a/src/drivers/mod.rs b/src/drivers/mod.rs
index fb6580a..36fbc6a 100644
--- a/src/drivers/mod.rs
+++ b/src/drivers/mod.rs
@@ -7,14 +7,19 @@ use async_trait::async_trait;
 use crate::{
     ast::{self, Emit, Target},
     errors::{format_err, Error, Result},
+    transforms::Transform,
 };
 
-use self::snowflake::{SnowflakeLocator, SNOWFLAKE_LOCATOR_PREFIX};
-use self::sqlite3::{SQLite3Locator, SQLITE3_LOCATOR_PREFIX};
+use self::{
+    snowflake::{SnowflakeLocator, SNOWFLAKE_LOCATOR_PREFIX},
+    sqlite3::{SQLite3Locator, SQLITE3_LOCATOR_PREFIX},
+    trino::{TrinoLocator, TRINO_LOCATOR_PREFIX},
+};
 
 pub mod bigquery;
 pub mod snowflake;
 pub mod sqlite3;
+pub mod trino;
 
 /// A URL-like locator for a database.
 #[async_trait]
@@ -37,6 +42,7 @@ impl FromStr for Box<dyn Locator> {
         match prefix {
             SQLITE3_LOCATOR_PREFIX => Ok(Box::new(s.parse::<SQLite3Locator>()?)),
             SNOWFLAKE_LOCATOR_PREFIX => Ok(Box::new(s.parse::<SnowflakeLocator>()?)),
+            TRINO_LOCATOR_PREFIX => Ok(Box::new(s.parse::<TrinoLocator>()?)),
             _ => Err(format_err!("unsupported database type: {}", s)),
         }
     }
@@ -76,17 +82,35 @@ pub trait Driver: Send + Sync + 'static {
         self.execute_native_sql_statement(&sql).await
     }
 
+    /// Get a list of transformations that should be applied to the AST before
+    /// executing it.
+    fn transforms(&self) -> Vec<Box<dyn Transform>> {
+        vec![]
+    }
+
     /// Rewrite an AST to convert function names, etc., into versions that can
     /// be passed to [`Emitted::emit_to_string`] for this database. This allows
     /// us to do less database-specific work in [`Emit::emit`], and more in the
     /// database drivers themselves. This can't change lexical syntax, but it
     /// can change the structure of the AST.
     fn rewrite_ast<'ast>(&self, ast: &'ast ast::SqlProgram) -> Result<RewrittenAst<'ast>> {
-        // Default implementation does nothing.
-        Ok(RewrittenAst {
-            extra_native_sql: vec![],
-            ast: Cow::Borrowed(ast),
-        })
+        let transforms = self.transforms();
+        if transforms.is_empty() {
+            return Ok(RewrittenAst {
+                extra_native_sql: vec![],
+                ast: Cow::Borrowed(ast),
+            });
+        } else {
+            let mut rewritten = ast.clone();
+            let mut extra_native_sql = vec![];
+            for transform in transforms {
+                extra_native_sql.extend(transform.transform(&mut rewritten)?);
+            }
+            Ok(RewrittenAst {
+                extra_native_sql,
+                ast: Cow::Owned(rewritten),
+            })
+        }
     }
 
     /// Drop a table if it exists.
diff --git a/src/drivers/snowflake/mod.rs b/src/drivers/snowflake/mod.rs
index c82aa79..b2ebf9e 100644
--- a/src/drivers/snowflake/mod.rs
+++ b/src/drivers/snowflake/mod.rs
@@ -1,10 +1,9 @@
 //! A Snowflake driver.
 
-use std::{borrow::Cow, env, fmt, str::FromStr};
+use std::{env, fmt, str::FromStr};
 
 use arrow_json::writer::record_batches_to_json_rows;
 use async_trait::async_trait;
-use derive_visitor::DriveMut;
 use once_cell::sync::Lazy;
 use regex::Regex;
 use serde_json::Value;
@@ -14,15 +13,38 @@ use tracing::{debug, instrument};
 use crate::{
     ast::{self, Emit, Target},
     errors::{format_err, Context, Error, Result},
+    transforms::{self, Transform, Udf},
 };
 
-use super::{sqlite3::SQLite3Ident, Column, Driver, DriverImpl, Locator, RewrittenAst};
-
-mod rename_functions;
+use super::{sqlite3::SQLite3Ident, Column, Driver, DriverImpl, Locator};
 
 /// Locator prefix for Snowflake.
 pub const SNOWFLAKE_LOCATOR_PREFIX: &str = "snowflake:";
 
+// A `phf_map!` of BigQuery function names to Snowflake function names. Use
+// this for simple renaming.
+static FUNCTION_NAMES: phf::Map<&'static str, &'static str> = phf::phf_map! {
+    "ARRAY_LENGTH" => "ARRAY_SIZE",
+    "GENERATE_UUID" => "UUID_STRING",
+    "REGEXP_EXTRACT" => "REGEXP_SUBSTR",
+    "SHA256" => "SHA2_BINARY", // Second argument defaults to SHA256.
+};
+
+/// A `phf_map!` of BigQuery UDF names to Snowflake UDFs. Use this when we
+/// actually need to create a UDF as a helper function.
+static UDFS: phf::Map<&'static str, &'static Udf> = phf::phf_map! {
+    "RAND" => &Udf { decl: "RAND() RETURNS FLOAT", sql: "UNIFORM(0::float, 1::float, RANDOM())" },
+    "TO_HEX" => &Udf { decl: "TO_HEX(b BINARY) RETURNS STRING", sql: "HEX_ENCODE(b, 0)" },
+};
+
+/// Format a UDF for use with Snowflake.
+fn format_udf(udf: &Udf) -> String {
+    format!(
+        "CREATE OR REPLACE TEMP FUNCTION {} AS $$\n{}\n$$\n",
+        udf.decl, udf.sql
+    )
+}
+
 /// A locator for a Snowflake database.
 ///
 /// For now, we will format these as:
@@ -208,7 +230,7 @@ impl Driver for SnowflakeDriver {
         }
 
         // We can only execute one statement at a time.
-        for statement in &rewritten.ast.statements {
+        for statement in rewritten.ast.statements.node_iter() {
             let sql = statement.emit_to_string(self.target());
             self.execute_native_sql_statement(&sql).await?;
         }
@@ -220,19 +242,12 @@ impl Driver for SnowflakeDriver {
             .context("could not end Snowflake session")
     }
 
-    fn rewrite_ast<'ast>(&self, ast: &'ast ast::SqlProgram) -> Result<RewrittenAst<'ast>> {
-        let mut ast = ast.clone();
-        let mut renamer = rename_functions::RenameFunctions::default();
-        ast.drive_mut(&mut renamer);
-        let extra_native_sql = renamer
-            .udfs
-            .values()
-            .map(|udf| udf.to_sql())
-            .collect::<Vec<_>>();
-        Ok(RewrittenAst {
-            extra_native_sql,
-            ast: Cow::Owned(ast),
-        })
+    fn transforms(&self) -> Vec<Box<dyn Transform>> {
+        vec![Box::new(transforms::RenameFunctions::new(
+            &FUNCTION_NAMES,
+            &UDFS,
+            &format_udf,
+        ))]
     }
 
     #[instrument(skip(self))]
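
For reference, `format_udf` applied to the `TO_HEX` entry above yields the statement that `rewrite_ast` will run before the main program:

```rust
let udf = Udf { decl: "TO_HEX(b BINARY) RETURNS STRING", sql: "HEX_ENCODE(b, 0)" };
assert_eq!(
    format_udf(&udf),
    "CREATE OR REPLACE TEMP FUNCTION TO_HEX(b BINARY) RETURNS STRING AS $$\nHEX_ENCODE(b, 0)\n$$\n",
);
```
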
diff --git a/src/drivers/snowflake/rename_functions.rs b/src/drivers/snowflake/rename_functions.rs
deleted file mode 100644
index 3dae619..0000000
--- a/src/drivers/snowflake/rename_functions.rs
+++ /dev/null
@@ -1,68 +0,0 @@
-//! A simple tree-walker that renames functions to their Snowflake equivalents.
-
-use std::collections::HashMap;
-
-use derive_visitor::VisitorMut;
-
-use crate::ast::{FunctionName, Identifier};
-
-// A `phf_map!` of BigQuery function names to Snowflake function names. Use
-// this for simple renaming.
-static FUNCTION_NAMES: phf::Map<&'static str, &'static str> = phf::phf_map! {
-    "ARRAY_LENGTH" => "ARRAY_SIZE",
-    "GENERATE_UUID" => "UUID_STRING",
-    "REGEXP_EXTRACT" => "REGEXP_SUBSTR",
-    "SHA256" => "SHA2_BINARY", // Second argument defaults to SHA256.
-};
-
-/// A Snowflake UDF (user-defined function).
-pub struct Udf {
-    pub decl: &'static str,
-    pub sql: &'static str,
-}
-
-impl Udf {
-    /// Generate the SQL to create this UDF.
-    pub fn to_sql(&self) -> String {
-        format!(
-            "CREATE OR REPLACE TEMP FUNCTION {} AS $$\n{}\n$$\n",
-            self.decl, self.sql
-        )
-    }
-}
-
-/// A `phf_map!` of BigQuery UDF names to Snowflake UDFs. Use this when we
-/// actually need to create a UDF as a helper function.
-static UDFS: phf::Map<&'static str, &'static Udf> = phf::phf_map! {
-    "RAND" => &Udf { decl: "RAND() RETURNS FLOAT", sql: "UNIFORM(0::float, 1::float, RANDOM())" },
-    "TO_HEX" => &Udf { decl: "TO_HEX(b BINARY) RETURNS STRING", sql: "HEX_ENCODE(b, 0)" },
-};
-
-#[derive(Default, VisitorMut)]
-#[visitor(FunctionName(enter))]
-pub struct RenameFunctions {
-    // UDFs that we need to create, if we haven't already.
-    pub udfs: HashMap<String, &'static Udf>,
-}
-
-impl RenameFunctions {
-    fn enter_function_name(&mut self, function_name: &mut FunctionName) {
-        if let FunctionName::Function { function } = function_name {
-            let name = function.unescaped_bigquery().to_ascii_uppercase();
-            if let Some(snowflake_name) = FUNCTION_NAMES.get(&name) {
-                // Rename the function.
-                let orig_ident = function_name.function_identifier();
-                *function_name = FunctionName::Function {
-                    function: Identifier {
-                        token: orig_ident.token.with_token_str(snowflake_name),
-                        text: snowflake_name.to_string(),
-                    },
-                };
-            } else if let Some(udf) = UDFS.get(&name) {
-                // We'll need a UDF, so add it to our list it if isn't already
-                // there.
-                self.udfs.insert(name, udf);
-            }
-        }
-    }
-}
diff --git a/src/drivers/trino/mod.rs b/src/drivers/trino/mod.rs
new file mode 100644
index 0000000..55bfca6
--- /dev/null
+++ b/src/drivers/trino/mod.rs
@@ -0,0 +1,254 @@
+//! Trino and maybe Presto driver.
+
+use std::{fmt, str::FromStr};
+
+use async_trait::async_trait;
+use once_cell::sync::Lazy;
+use prusto::{Client, ClientBuilder, Presto, Row};
+use regex::Regex;
+use tracing::debug;
+
+use crate::{
+    ast::{self, Emit, Target},
+    drivers::sqlite3::SQLite3String,
+    errors::{format_err, Context, Error, Result},
+    transforms::{self, Transform},
+};
+
+use super::{sqlite3::SQLite3Ident, Column, Driver, DriverImpl, Locator};
+
+/// Our locator prefix.
+pub const TRINO_LOCATOR_PREFIX: &str = "trino:";
+
+/// A locator for a Trino database. May or may not also work for Presto.
+#[derive(Debug)]
+pub struct TrinoLocator {
+    user: String,
+    host: String,
+    port: Option<u16>,
+    catalog: String,
+    schema: String,
+}
+
+impl fmt::Display for TrinoLocator {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "trino://{}@{}", self.user, self.host,)?;
+        if let Some(port) = self.port {
+            write!(f, ":{}", port)?;
+        }
+        write!(f, "/{}/{}", self.catalog, self.schema)
+    }
+}
+
+// Use `once_cell` and `regex` to parse our locator.
+impl FromStr for TrinoLocator {
+    type Err = Error;
+
+    fn from_str(s: &str) -> Result<Self> {
+        static LOCATOR_RE: Lazy<Regex> = Lazy::new(|| {
+            Regex::new(
+                r"(?x)
+                ^trino://
+                (?P<user>[^@]+)@
+                (?P<host>[^:/]+)
+                (?::(?P<port>\d+))?
+                /(?P<catalog>[^/]+)
+                /(?P<schema>[^/]+)
+                $",
+            )
+            .unwrap()
+        });
+
+        let captures = LOCATOR_RE
+            .captures(s)
+            .ok_or_else(|| format_err!("Invalid Trino locator: {}", s))?;
+
+        Ok(Self {
+            user: captures.name("user").unwrap().as_str().to_owned(),
+            host: captures.name("host").unwrap().as_str().to_owned(),
+            port: captures
+                .name("port")
+                .map(|m| m.as_str().parse::<u16>().unwrap()),
+            catalog: captures.name("catalog").unwrap().as_str().to_owned(),
+            schema: captures.name("schema").unwrap().as_str().to_owned(),
+        })
+    }
+}
+
+#[async_trait]
+impl Locator for TrinoLocator {
+    fn target(&self) -> Target {
+        Target::Trino
+    }
+
+    async fn driver(&self) -> Result<Box<dyn Driver>> {
+        Ok(Box::new(TrinoDriver::from_locator(self)?))
+    }
+}
+
+/// A Trino driver.
+pub struct TrinoDriver {
+    catalog: String,
+    schema: String,
+    client: Client,
+}
+
+impl TrinoDriver {
+    /// Create a new Trino driver from a locator.
+    pub fn from_locator(locator: &TrinoLocator) -> Result<Self> {
+        let client = ClientBuilder::new(&locator.user, &locator.host)
+            .port(locator.port.unwrap_or(8080))
+            .catalog(&locator.catalog)
+            .schema(&locator.schema)
+            .build()
+            .with_context(|| format!("Failed to connect to Trino: {}", locator))?;
+        Ok(Self {
+            client,
+            catalog: locator.catalog.clone(),
+            schema: locator.schema.clone(),
+        })
+    }
+}
+
+#[async_trait]
+impl Driver for TrinoDriver {
+    fn target(&self) -> Target {
+        Target::Trino
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn execute_native_sql_statement(&mut self, sql: &str) -> Result<()> {
+        debug!(%sql, "Executing native SQL statement");
+        self.client
+            .execute(sql.to_owned())
+            .await
+            .with_context(|| format!("Failed to execute SQL: {}", sql))?;
+        Ok(())
+    }
+
+    #[tracing::instrument(skip_all)]
+    async fn execute_ast(&mut self, ast: &ast::SqlProgram) -> Result<()> {
+        let rewritten = self.rewrite_ast(ast)?;
+        for sql in rewritten.extra_native_sql {
+            self.execute_native_sql_statement(&sql).await?;
+        }
+
+        // We can only execute one statement at a time.
+        for statement in rewritten.ast.statements.node_iter() {
+            let sql = statement.emit_to_string(self.target());
+            self.execute_native_sql_statement(&sql).await?;
+        }
+        Ok(())
+    }
+
+    fn transforms(&self) -> Vec<Box<dyn Transform>> {
+        vec![Box::new(transforms::OrReplaceToDropIfExists)]
+    }
+
+    #[tracing::instrument(skip(self))]
+    async fn drop_table_if_exists(&mut self, table_name: &str) -> Result<()> {
+        self.client
+            .execute(format!("DROP TABLE IF EXISTS {}", SQLite3Ident(table_name)))
+            .await
+            .with_context(|| format!("Failed to drop table: {}", table_name))?;
+        Ok(())
+    }
+
+    #[tracing::instrument(skip(self))]
+    async fn compare_tables(&mut self, result_table: &str, expected_table: &str) -> Result<()> {
+        self.compare_tables_impl(result_table, expected_table).await
+    }
+}
+
+#[async_trait]
+impl DriverImpl for TrinoDriver {
+    type Type = String;
+    type Value = serde_json::Value;
+    type Rows = Box<dyn Iterator<Item = Result<Vec<Self::Value>>> + Send + Sync>;
+
+    #[tracing::instrument(skip(self))]
+    async fn table_columns(&mut self, table_name: &str) -> Result<Vec<Column<Self::Type>>> {
+        #[derive(Debug, Presto)]
+        #[allow(non_snake_case)]
+        struct Col {
+            col: String,
+            ty: String,
+        }
+
+        let sql = format!(
+            "SELECT column_name AS col, data_type AS ty
+             FROM information_schema.columns
+             WHERE table_catalog = {} AND table_schema = {} AND table_name = {}",
+            // TODO: Replace with real string escapes.
+            SQLite3String(&self.catalog),
+            SQLite3String(&self.schema),
+            SQLite3String(table_name)
+        );
+        Ok(self
+            .client
+            .get_all::<Col>(sql)
+            .await
+            .with_context(|| format!("Failed to get columns for table: {}", table_name))?
+            .into_vec()
+            .into_iter()
+            .map(|c| Column {
+                name: c.col,
+                ty: c.ty,
+            })
+            .collect())
+    }
+
+    #[tracing::instrument(skip(self))]
+    async fn query_table_sorted(
+        &mut self,
+        table_name: &str,
+        columns: &[Column<Self::Type>],
+    ) -> Result<Self::Rows> {
+        let cols_sql = columns
+            .iter()
+            .map(|c| SQLite3Ident(&c.name).to_string())
+            .collect::<Vec<_>>()
+            .join(", ");
+        let sql = format!(
+            "SELECT {} FROM {} ORDER BY {}",
+            cols_sql,
+            SQLite3Ident(table_name),
+            cols_sql
+        );
+        let rows = self
+            .client
+            .get_all::<Row>(sql)
+            .await
+            .with_context(|| format!("Failed to query table: {}", table_name))?
+            .into_vec()
+            .into_iter()
+            .map(|r| Ok(r.into_json()));
+        Ok(Box::new(rows))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_trino_locator() {
+        let locator: TrinoLocator = "trino://user@host:1234/catalog/schema"
+            .parse()
+            .expect("parse failed");
+        assert_eq!(locator.user, "user");
+        assert_eq!(locator.host, "host");
+        assert_eq!(locator.port, Some(1234));
+        assert_eq!(locator.catalog, "catalog");
+        assert_eq!(locator.schema, "schema");
+
+        let locator: TrinoLocator = "trino://user@host/catalog/schema"
+            .parse()
+            .expect("parse failed");
+        assert_eq!(locator.user, "user");
+        assert_eq!(locator.host, "host");
+        assert_eq!(locator.port, None);
+        assert_eq!(locator.catalog, "catalog");
+        assert_eq!(locator.schema, "schema");
+    }
+}
diff --git a/src/main.rs b/src/main.rs
index da72e99..5099804 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -7,6 +7,7 @@ mod ast;
 mod cmd;
 mod drivers;
 mod errors;
+mod transforms;
 mod util;
 
 use cmd::{
diff --git a/src/transforms/mod.rs b/src/transforms/mod.rs
new file mode 100644
index 0000000..fadc932
--- /dev/null
+++ b/src/transforms/mod.rs
@@ -0,0 +1,31 @@
+//! Ways to transform an [`ast::SqlProgram`] into another `ast::SqlProgram`.
+//!
+//! These work entirely in terms of parsed BigQuery SQL. Normally, the ideal
+//! transform should take a `SqlProgram` containing BigQuery-specific features,
+//! and return one that uses the equivalent standard SQL features. This allows
+//! us to use the same transforms for as many databases as possible.
+
+use crate::{ast, errors::Result};
+
+pub use self::{
+    or_replace_to_drop_if_exists::OrReplaceToDropIfExists,
+    rename_functions::{RenameFunctions, Udf},
+};
+
+mod or_replace_to_drop_if_exists;
+mod rename_functions;
+
+/// A transform that modifies an [`ast::SqlProgram`].
+pub trait Transform {
+    /// Apply this transform to an [`ast::SqlProgram`].
+    ///
+    /// Returns a list of extra SQL statements that need to be executed before
+    /// the transformed program. These statements must be in the target dialect
+    /// of SQL, and typically include custom UDFs or similar temporary
+    /// definitions.
+    ///
+    /// A transform should only be used once, as it may modify itself in the
+    /// process of transforming the AST. To enforce this, the transform takes
+    /// `self: Box<Self>` rather than `&mut self`.
+    fn transform(self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<Vec<String>>;
+}
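
Implementing the trait is deliberately lightweight. A sketch of a transform that leaves the AST alone and only returns setup SQL (hypothetical, for illustration; the session property shown is an assumption):

```rust
use crate::{ast, errors::Result, transforms::Transform};

/// Hypothetical transform illustrating the contract.
struct SessionSetup;

impl Transform for SessionSetup {
    fn transform(self: Box<Self>, _sql_program: &mut ast::SqlProgram) -> Result<Vec<String>> {
        // Mutate `_sql_program` here if needed, then return any native-SQL
        // statements the driver must execute first (UDFs, session settings, ...).
        Ok(vec!["SET SESSION query_max_run_time = '10m'".to_owned()])
    }
}
```
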
diff --git a/src/transforms/or_replace_to_drop_if_exists.rs b/src/transforms/or_replace_to_drop_if_exists.rs
new file mode 100644
index 0000000..49532de
--- /dev/null
+++ b/src/transforms/or_replace_to_drop_if_exists.rs
@@ -0,0 +1,80 @@
+//! Transform `OR REPLACE` to the equivalent `DROP IF EXISTS`.
+
+use crate::{
+    ast::{self, CreateTableStatement, CreateViewStatement, NodeOrSep},
+    errors::Result,
+};
+
+use super::Transform;
+
+/// Transform `OR REPLACE` to the equivalent `DROP IF EXISTS`.
+pub struct OrReplaceToDropIfExists;
+
+impl Transform for OrReplaceToDropIfExists {
+    fn transform(self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<Vec<String>> {
+        let old_statements = sql_program.statements.take();
+        for mut node_or_sep in old_statements {
+            match &mut node_or_sep {
+                NodeOrSep::Node(ast::Statement::CreateTable(CreateTableStatement {
+                    create_token,
+                    or_replace: or_replace @ Some(_),
+                    table_token,
+                    table_name,
+                    ..
+                })) => {
+                    // Insert a `DROP TABLE IF EXISTS` statement before the `CREATE TABLE`.
+                    sql_program.statements.push(ast::Statement::DropTable(
+                        ast::DropTableStatement {
+                            // For now, give DROP the same whitespace and source
+                            // location as the original CREATE.
+                            drop_token: create_token.with_token_str("DROP"),
+                            table_token: table_token.clone(),
+                            if_exists: Some(if_exists_clause(table_token)),
+                            table_name: table_name.clone(),
+                        },
+                    ));
+
+                    // Remove the `OR REPLACE` clause.
+                    *or_replace = None;
+                }
+                NodeOrSep::Node(ast::Statement::CreateView(CreateViewStatement {
+                    create_token,
+                    or_replace: or_replace @ Some(_),
+                    view_token,
+                    view_name,
+                    ..
+                })) => {
+                    // Insert a `DROP VIEW IF EXISTS` statement before the `CREATE VIEW`.
+                    sql_program
+                        .statements
+                        .push(ast::Statement::DropView(ast::DropViewStatement {
+                            // For now, give DROP the same whitespace and source
+                            // location as the original CREATE.
+                            drop_token: create_token.with_token_str("DROP"),
+                            view_token: view_token.clone(),
+                            if_exists: Some(if_exists_clause(view_token)),
+                            view_name: view_name.clone(),
+                        }));
+
+                    // Remove the `OR REPLACE` clause.
+                    *or_replace = None;
+                }
+                _ => {}
+            }
+            sql_program.statements.push_node_or_sep(node_or_sep);
+        }
+        Ok(vec![])
+    }
+}
+
+/// Generate an `IF EXISTS` clause for a `DROP` statement.
+fn if_exists_clause(token_for_span: &mut ast::Token) -> ast::IfExists {
+    // TODO: We really need to formalize how we create synthetic tokens.
+    ast::IfExists {
+        if_token: token_for_span.with_token_str("IF").ensure_ws().into_owned(),
+        exists_token: token_for_span
+            .with_token_str("EXISTS")
+            .ensure_ws()
+            .into_owned(),
+    }
+}
diff --git a/src/transforms/rename_functions.rs b/src/transforms/rename_functions.rs
new file mode 100644
index 0000000..a2c65df
--- /dev/null
+++ b/src/transforms/rename_functions.rs
@@ -0,0 +1,84 @@
+//! A simple tree-walker that renames functions to their Snowflake equivalents.
+
+use std::collections::HashMap;
+
+use derive_visitor::{DriveMut, VisitorMut};
+
+use crate::{
+    ast::{self, FunctionName, Identifier},
+    errors::Result,
+};
+
+use super::Transform;
+
+/// A Snowflake UDF (user-defined function).
+pub struct Udf {
+    pub decl: &'static str,
+    pub sql: &'static str,
+}
+
+#[derive(VisitorMut)]
+#[visitor(FunctionName(enter))]
+pub struct RenameFunctions {
+    // Lookup table containing function replacements.
+    function_table: &'static phf::Map<&'static str, &'static str>,
+
+    // Lookup table containing UDF replacements.
+    udf_table: &'static phf::Map<&'static str, &'static Udf>,
+
+    // Format a UDF.
+    format_udf: &'static dyn Fn(&Udf) -> String,
+
+    // UDFs that we need to create, if we haven't already.
+    udfs: HashMap<String, &'static Udf>,
+}
+
+impl RenameFunctions {
+    /// Create a new `RenameFunctions` visitor.
+    pub fn new(
+        function_table: &'static phf::Map<&'static str, &'static str>,
+        udf_table: &'static phf::Map<&'static str, &'static Udf>,
+        format_udf: &'static dyn Fn(&Udf) -> String,
+    ) -> Self {
+        Self {
+            function_table,
+            udf_table,
+            format_udf,
+            udfs: HashMap::new(),
+        }
+    }
+
+    fn enter_function_name(&mut self, function_name: &mut FunctionName) {
+        if let FunctionName::Function { function } = function_name {
+            let name = function.unescaped_bigquery().to_ascii_uppercase();
+            if let Some(snowflake_name) = self.function_table.get(&name) {
+                // Rename the function.
+                let orig_ident = function_name.function_identifier();
+                *function_name = FunctionName::Function {
+                    function: Identifier {
+                        token: orig_ident.token.with_token_str(snowflake_name),
+                        text: snowflake_name.to_string(),
+                    },
+                };
+            } else if let Some(udf) = self.udf_table.get(&name) {
+                // We'll need a UDF, so add it to our list if it isn't already
+                // there.
+                self.udfs.insert(name, udf);
+            }
+        }
+    }
+}
+
+impl Transform for RenameFunctions {
+    fn transform(mut self: Box<Self>, sql_program: &mut ast::SqlProgram) -> Result<Vec<String>> {
+        // Walk the AST, renaming functions and collecting UDFs.
+        sql_program.drive_mut(self.as_mut());
+
+        // Create any UDFs that we need.
+        let mut extra_sql = vec![];
+        for udf in self.udfs.values() {
+            extra_sql.push((self.format_udf)(udf));
+        }
+        Ok(extra_sql)
+    }
+}
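
Taken together, `OrReplaceToDropIfExists` plus the Trino `DataType` mapping give Trino the same `OR REPLACE` emulation SQLite3 already had. A rough end-to-end sketch (`transpile_for_trino` is an assumed helper, not part of the patch; exact whitespace will differ):

```rust
let out = transpile_for_trino("CREATE OR REPLACE TABLE t2 (a INT64, b INT64)")?;
// Emitted as two separate statements, roughly:
//   DROP TABLE IF EXISTS t2;
//   CREATE TABLE t2 (a BIGINT, b BIGINT)
assert!(out.contains("DROP TABLE IF EXISTS t2"));
assert!(out.contains("BIGINT"));
```
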