From 32f77c4efbd798691f045976f2c5d5a059d2aa07 Mon Sep 17 00:00:00 2001 From: Chase Willden Date: Sun, 17 Nov 2024 23:23:22 -0700 Subject: [PATCH 1/2] experimenting with sql! --- Cargo.lock | 5 +- njord/src/sqlite/select.rs | 36 +++++++- njord/tests/sqlite/select_test.rs | 57 ++++++++++++ njord_derive/Cargo.toml | 3 +- njord_derive/src/lib.rs | 149 +++++++++++++++++++++++++++++- 5 files changed, 242 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9ec9106e..3f2cc410 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1390,6 +1390,7 @@ version = "0.4.0" dependencies = [ "proc-macro2", "quote", + "regex", "rusqlite", "syn 2.0.87", ] @@ -1760,9 +1761,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", diff --git a/njord/src/sqlite/select.rs b/njord/src/sqlite/select.rs index 89f4b4dd..29cc3ab0 100644 --- a/njord/src/sqlite/select.rs +++ b/njord/src/sqlite/select.rs @@ -54,9 +54,7 @@ use crate::util::{Join, JoinType}; /// # Returns /// /// A `SelectQueryBuilder` instance. -pub fn select<'a, T: Table + Default>( - columns: Vec>, -) -> SelectQueryBuilder<'a, T> { +pub fn select<'a, T: Table + Default>(columns: Vec>) -> SelectQueryBuilder<'a, T> { SelectQueryBuilder::new(columns) } @@ -348,7 +346,7 @@ impl<'a, T: Table + Default> SelectQueryBuilder<'a, T> { } /// Builds and executes the SELECT query. - /// + /// /// # Arguments /// /// * `conn` - A reference to the database connection. @@ -406,3 +404,33 @@ where self.build_query() } } + +pub fn raw_execute(sql: &str, conn: &Connection) -> Result> { + let mut binding = conn.prepare(sql)?; + let iter = binding.query_map((), |row| { + let mut instance = T::default(); + let columns = instance.get_column_fields(); + + for (index, column) in columns.iter().enumerate() { + let value = row.get::(index)?; + + let string_value = match value { + Value::Integer(val) => val.to_string(), + Value::Null => String::new(), + Value::Real(val) => val.to_string(), + Value::Text(val) => val.to_string(), + Value::Blob(val) => String::from_utf8_lossy(&val).to_string(), + }; + + instance.set_column_value(column, &string_value); + } + + Ok(instance) + })?; + + let result: Result> = iter + .map(|row_result| row_result.and_then(|row| Ok(row))) + .collect::>>(); + + result.map_err(|err| err.into()) +} diff --git a/njord/tests/sqlite/select_test.rs b/njord/tests/sqlite/select_test.rs index 1f5e246a..b6aaf9d7 100644 --- a/njord/tests/sqlite/select_test.rs +++ b/njord/tests/sqlite/select_test.rs @@ -2,6 +2,7 @@ use njord::condition::Condition; use njord::keys::AutoIncrementPrimaryKey; use njord::sqlite; use njord::{column::Column, condition::Value}; +use njord_derive::sql; use std::collections::HashMap; use std::path::Path; @@ -525,3 +526,59 @@ fn select_in() { Err(e) => panic!("Failed to SELECT: {:?}", e), }; } + +#[test] +fn sql_bang() { + let user_id = 1; + + let query = sql! { + SELECT * + FROM user + WHERE id = {user_id} + }; + + assert_eq!(query.to_string(), "SELECT * FROM user WHERE id = 1"); + + let complex_query = sql! { + SELECT a.company, COUNT(i.id) AS total_impressions, COUNT(DISTINCT i.ip_address) AS unique_impressions + FROM impressions i + INNER JOIN cached_content c ON c.content_hash = i.content_hash + INNER JOIN ads a ON a.id = c.ad_id + GROUP BY a.company; + }; + + assert_eq!( + complex_query.to_string(), + "SELECT a.company, COUNT (i.id) AS total_impressions, COUNT (DISTINCT i.ip_address) AS unique_impressions \ + FROM impressions i \ + INNER JOIN cached_content c ON c.content_hash = i.content_hash \ + INNER JOIN ads a ON a.id = c.ad_id \ + GROUP BY a.company;" + ); +} + +#[test] +fn raw_execute() { + let db_relative_path = "./db/select.db"; + let db_path = Path::new(&db_relative_path); + let conn = sqlite::open(db_path); + + let username = "mjovanc"; + + let query = sql! { + SELECT * + FROM users + WHERE username = {username} + }; + + match conn { + Ok(ref c) => { + let result = sqlite::select::raw_execute::(&query, c); + match result { + Ok(r) => assert_eq!(r.len(), 2), + Err(e) => panic!("Failed to SELECT: {:?}", e), + }; + } + Err(e) => panic!("Failed to SELECT: {:?}", e), + }; +} diff --git a/njord_derive/Cargo.toml b/njord_derive/Cargo.toml index f4481527..83fb8882 100644 --- a/njord_derive/Cargo.toml +++ b/njord_derive/Cargo.toml @@ -19,5 +19,6 @@ proc-macro = true [dependencies] proc-macro2 = "1.0.89" quote = "1.0" -syn = "2.0.87" +syn = { version = "2.0.87", features = ["full"] } rusqlite = { version = "0.32.1", features = ["bundled"] } +regex = "1.11.1" \ No newline at end of file diff --git a/njord_derive/src/lib.rs b/njord_derive/src/lib.rs index 6e576cf7..35ee1f50 100644 --- a/njord_derive/src/lib.rs +++ b/njord_derive/src/lib.rs @@ -32,9 +32,11 @@ extern crate proc_macro; use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; -use quote::quote; use syn::{parse_macro_input, DeriveInput, FieldsNamed}; +use proc_macro2::{Delimiter, TokenTree as TokenTree2}; +use quote::quote; + use util::{extract_table_name, has_default_impl}; mod util; @@ -218,3 +220,148 @@ pub fn table_derive(input: TokenStream) -> TokenStream { output.into() } + +/// The procedural macro `sql!` takes a SQL-like syntax and transforms it into a string. +// #[proc_macro] +// pub fn sql(input: TokenStream) -> TokenStream { +// /* +// GOAL: +// let id = 1; + +// let query = sql! { +// SELECT * +// FROM user +// WHERE id = {id} +// }; +// */ +// let input_string = input.to_string(); + +// // Remove the outer quotes +// let input_string = input_string.trim_matches(|c| c == '"' || c == '`' || c == '\''); + +// let expanded = quote! { +// { +// #input_string +// } +// }; + +// expanded.into() +// } + +#[proc_macro] +pub fn sql(input: TokenStream) -> TokenStream { + let input: proc_macro2::TokenStream = input.into(); + let mut tokens = input.into_iter().peekable(); + let mut sql_parts = Vec::new(); + let mut expressions = Vec::new(); + let mut param_types = Vec::new(); + let mut current_sql = String::new(); + let mut last_token_type = TokenType::Other; + + #[derive(PartialEq, Clone)] + enum TokenType { + Dot, + OpenParen, + CloseParen, + Operator, + Other, + } + + while let Some(token) = tokens.next() { + match token { + TokenTree2::Group(group) if group.delimiter() == Delimiter::Brace => { + if !current_sql.is_empty() { + sql_parts.push(current_sql); + current_sql = String::new(); + } + + // Parse the expression to determine its type + let expr = group.stream(); + let expr_str = expr.to_string(); + + // Check if it's an identifier (likely a string variable) + let needs_quotes = !expr_str.contains("as") + && !expr_str.contains("::") + && !expr_str.starts_with("Some") + && !expr_str.parse::().is_ok() + && !expr_str.parse::().is_ok(); + + if needs_quotes { + sql_parts.push("'{}'".to_string()); + } else { + sql_parts.push("{}".to_string()); + } + + expressions.push(expr); + param_types.push(needs_quotes); + last_token_type = TokenType::Other; + } + token => { + let token_str = token.to_string(); + let current_token_type = match token_str.as_str() { + "." => TokenType::Dot, + "(" => TokenType::OpenParen, + ")" => TokenType::CloseParen, + "=" | ">" | "<" | ">=" | "<=" | "!=" => TokenType::Operator, + _ => TokenType::Other, + }; + match current_token_type { + TokenType::Dot => { + current_sql.push('.'); + } + TokenType::OpenParen => { + current_sql.push('('); + } + TokenType::CloseParen => { + current_sql.push(')'); + if let Some(next) = tokens.peek() { + let next_str = next.to_string(); + if !matches!(next_str.as_str(), "," | "." | ")" | ";") { + current_sql.push(' '); + } + } + } + TokenType::Operator => { + if !current_sql.ends_with(' ') { + current_sql.push(' '); + } + current_sql.push_str(&token_str); + current_sql.push(' '); + } + TokenType::Other => { + let needs_space = !current_sql.is_empty() + && !current_sql.ends_with(' ') + && !matches!(last_token_type, TokenType::Dot | TokenType::OpenParen) + && token_str != "," + && token_str != ";"; + if needs_space { + current_sql.push(' '); + } + current_sql.push_str(&token_str); + if token_str == "," { + current_sql.push(' '); + } + } + } + last_token_type = current_token_type; + } + } + } + + if !current_sql.is_empty() { + sql_parts.push(current_sql); + } + + let sql_format = sql_parts.join(""); + let expanded = if expressions.is_empty() { + quote! { + #sql_format.to_string() + } + } else { + quote! { + format!(#sql_format #(,#expressions)*) + } + }; + + expanded.into() +} From 55f28445366a4f2e04d5198a4ef5220ff64405f7 Mon Sep 17 00:00:00 2001 From: Chase Willden Date: Sun, 17 Nov 2024 23:31:00 -0700 Subject: [PATCH 2/2] fix test --- njord/tests/sqlite/select_test.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/njord/tests/sqlite/select_test.rs b/njord/tests/sqlite/select_test.rs index b6aaf9d7..b77c1de5 100644 --- a/njord/tests/sqlite/select_test.rs +++ b/njord/tests/sqlite/select_test.rs @@ -537,7 +537,7 @@ fn sql_bang() { WHERE id = {user_id} }; - assert_eq!(query.to_string(), "SELECT * FROM user WHERE id = 1"); + assert_eq!(query.to_string(), "SELECT * FROM user WHERE id = '1'"); let complex_query = sql! { SELECT a.company, COUNT(i.id) AS total_impressions, COUNT(DISTINCT i.ip_address) AS unique_impressions