From 41c9ede5f9c21dcda9d2f7c050fc31bf18f74de6 Mon Sep 17 00:00:00 2001 From: Michael Cuffaro Date: Mon, 16 Sep 2024 09:03:47 -0400 Subject: [PATCH] add helper function for sql bindings, get_mixed_query_params() --- src/toolkit.rs | 46 +++++++++++++++++++ src/validate.rs | 117 ++++++++++++++++++++++++++++++------------------ src/valve.rs | 30 ++++++++++--- 3 files changed, 142 insertions(+), 51 deletions(-) diff --git a/src/toolkit.rs b/src/toolkit.rs index 0a8a98a..13e94df 100644 --- a/src/toolkit.rs +++ b/src/toolkit.rs @@ -156,6 +156,14 @@ pub enum QueryAsIfKind { Replace, } +/// Used to represent a generic query parameter for binding to a SQLx query. +pub enum QueryParam { + Numeric(f64), + Real(f64), + Integer(i32), + String(String), +} + /// Given a string representing the location of a database, return a database connection pool. pub async fn get_pool_from_connection_string(database: &str) -> Result { let connection_options; @@ -3743,6 +3751,44 @@ pub fn compile_condition( } } +/// Given a list of [SerdeValue]s and the SQL type of the column that they come from, return +/// a SQL string consisting of a comma-separated list of [SQL_PARAM] placeholders to use for the +/// binding, and the list of parameters that will need to be bound to the string before executing. +pub fn get_mixed_query_params( + values: &Vec, + sql_type: &str, +) -> (String, Vec) { + let mut param_values = vec![]; + let mut param_placeholders = vec![]; + + for value in values { + param_placeholders.push(SQL_PARAM); + let param_value = value + .as_str() + .expect(&format!("'{}' is not a string", value)); + if sql_type == "numeric" { + let numeric_value: f64 = param_value + .parse() + .expect(&format!("{param_value} is not numeric")); + param_values.push(QueryParam::Numeric(numeric_value)); + } else if sql_type == "integer" { + let integer_value: i32 = param_value + .parse() + .expect(&format!("{param_value} is not an integer")); + param_values.push(QueryParam::Integer(integer_value)); + } else if sql_type == "real" { + let real_value: f64 = param_value + .parse() + .expect(&format!("{param_value} is not a real")); + param_values.push(QueryParam::Real(real_value)); + } else { + param_values.push(QueryParam::String(param_value.to_string())); + } + } + + (param_placeholders.join(", "), param_values) +} + /// Given the config map, the name of a datatype, and a database connection pool used to determine /// the database type, climb the datatype tree (as required), and return the first 'SQL type' found. /// If there is no SQL type defined for the given datatype, return TEXT. diff --git a/src/validate.rs b/src/validate.rs index 428f168..d8d63e9 100644 --- a/src/validate.rs +++ b/src/validate.rs @@ -4,9 +4,9 @@ use crate::{ ast::Expression, toolkit::{ cast_sql_param_from_text, get_column_value, get_column_value_as_string, - get_datatype_ancestors, get_sql_type_from_global_config, get_table_options_from_config, - get_value_type, is_sql_type_error, local_sql_syntax, ColumnRule, CompiledCondition, - QueryAsIf, QueryAsIfKind, ValueType, + get_datatype_ancestors, get_mixed_query_params, get_sql_type_from_global_config, + get_table_options_from_config, get_value_type, is_sql_type_error, local_sql_syntax, + ColumnRule, CompiledCondition, QueryAsIf, QueryAsIfKind, QueryParam, ValueType, }, valve::{ ValveCell, ValveCellMessage, ValveConfig, ValveRow, ValveRuleConfig, ValveTreeConstraint, @@ -309,34 +309,42 @@ pub async fn validate_rows_constraints( let sql_type = get_sql_type_from_global_config(config, &fkey.ftable, &fkey.fcolumn, pool) .to_lowercase(); - - // TODO: Here. - let values_str = received_values + let values = received_values .get(&*fkey.column) .unwrap() .iter() - .map(|value| { - let value = value - .as_str() - .expect(&format!("'{}' is not a string", value)); - if vec!["integer", "numeric", "real"].contains(&sql_type.as_str()) { - format!("{}", value) - } else { - format!("'{}'", value.replace("'", "''")) - } + .filter(|value| { + !is_sql_type_error( + &sql_type, + value + .as_str() + .expect(&format!("'{}' is not a string", value)), + ) }) - .filter(|value| !is_sql_type_error(&sql_type, value)) - .collect::>() - .join(", "); + .cloned() + .collect::>(); + let (lookup_sql, param_values) = get_mixed_query_params(&values, &sql_type); // Foreign keys always correspond to columns with unique constraints so we do not // need to use the keyword 'DISTINCT' when querying the normal version of the table: - let sql = format!( - r#"SELECT "{}" FROM "{}" WHERE "{}" IN ({})"#, - fkey.fcolumn, fkey.ftable, fkey.fcolumn, values_str + let sql = local_sql_syntax( + pool, + &format!( + r#"SELECT "{}" FROM "{}" WHERE "{}" IN ({})"#, + fkey.fcolumn, fkey.ftable, fkey.fcolumn, lookup_sql + ), ); + let mut query = sqlx_query(&sql); + for param_value in ¶m_values { + match param_value { + QueryParam::Integer(p) => query = query.bind(p), + QueryParam::Numeric(p) => query = query.bind(p), + QueryParam::Real(p) => query = query.bind(p), + QueryParam::String(p) => query = query.bind(p), + } + } - let allowed_values = sqlx_query(&sql) + let allowed_values = query .fetch_all(pool) .await? .iter() @@ -353,11 +361,23 @@ pub async fn validate_rows_constraints( // The conflict table has no keys other than on row_number so in principle // it could have duplicate values of the foreign constraint, therefore we // add the DISTINCT keyword here: - let sql = format!( - r#"SELECT DISTINCT "{}" FROM "{}_conflict" WHERE "{}" IN ({})"#, - fkey.fcolumn, fkey.ftable, fkey.fcolumn, values_str + let sql = local_sql_syntax( + pool, + &format!( + r#"SELECT DISTINCT "{}" FROM "{}_conflict" WHERE "{}" IN ({})"#, + fkey.fcolumn, fkey.ftable, fkey.fcolumn, lookup_sql + ), ); - sqlx_query(&sql) + let mut query = sqlx_query(&sql); + for param_value in ¶m_values { + match param_value { + QueryParam::Integer(p) => query = query.bind(p), + QueryParam::Numeric(p) => query = query.bind(p), + QueryParam::Real(p) => query = query.bind(p), + QueryParam::String(p) => query = query.bind(p), + } + } + query .fetch_all(pool) .await? .iter() @@ -411,33 +431,42 @@ pub async fn validate_rows_constraints( } }; - // TODO: Here. let sql_type = get_sql_type_from_global_config(config, &table, &column, pool).to_lowercase(); - let values_str = received_values + let values = received_values .get(&*column) .unwrap() .iter() - .map(|value| { - let value = value - .as_str() - .expect(&format!("'{}' is not a string", value)); - if vec!["integer", "numeric", "real"].contains(&sql_type.as_str()) { - format!("{}", value) - } else { - format!("'{}'", value.replace("'", "''")) - } + .filter(|value| { + !is_sql_type_error( + &sql_type, + value + .as_str() + .expect(&format!("'{}' is not a string", value)), + ) }) - .filter(|value| !is_sql_type_error(&sql_type, value)) - .collect::>() - .join(", "); + .cloned() + .collect::>(); + let (lookup_sql, param_values) = get_mixed_query_params(&values, &sql_type); - let sql = format!( - r#"SELECT {} "{}" FROM "{}" WHERE "{}" IN ({})"#, - query_modifier, column, query_table, column, values_str + let sql = local_sql_syntax( + pool, + &format!( + r#"SELECT {} "{}" FROM "{}" WHERE "{}" IN ({})"#, + query_modifier, column, query_table, column, lookup_sql + ), ); + let mut query = sqlx_query(&sql); + for param_value in ¶m_values { + match param_value { + QueryParam::Integer(p) => query = query.bind(p), + QueryParam::Numeric(p) => query = query.bind(p), + QueryParam::Real(p) => query = query.bind(p), + QueryParam::String(p) => query = query.bind(p), + } + } - let forbidden_values = sqlx_query(&sql) + let forbidden_values = query .fetch_all(pool) .await? .iter() diff --git a/src/valve.rs b/src/valve.rs index 0f48f5e..5a95a96 100644 --- a/src/valve.rs +++ b/src/valve.rs @@ -986,7 +986,6 @@ impl Valve { } } } else { - println!("FAVVOOOM"); let sql = format!( r#"SELECT ccu.table_name AS foreign_table_name, @@ -1961,21 +1960,38 @@ impl Valve { // Collect the paths and possibly the options of all of the tables that were requested to be // saved: let options_enabled = self.column_enabled_in_db("table", "options").await?; - // TODO: Here. + + // Build the query to get the path and options info from the table table: + let mut params = vec![]; + let sql_param_str = tables + .iter() + .map(|table| { + params.push(table); + SQL_PARAM.to_string() + }) + .collect::>() + .join(", "); let sql = { if options_enabled { format!( - r#"SELECT "table", "path", "options" FROM "table" WHERE "table" IN ('{}')"#, - tables.join("', '") + r#"SELECT "table", "path", "options" FROM "table" WHERE "table" IN ({})"#, + sql_param_str ) } else { format!( - r#"SELECT "table", "path" FROM "table" WHERE "table" IN ('{}')"#, - tables.join("', '") + r#"SELECT "table", "path" FROM "table" WHERE "table" IN ({})"#, + sql_param_str ) } }; - let mut stream = sqlx_query(&sql).fetch(&self.pool); + let sql = local_sql_syntax(&self.pool, &sql); + let mut query = sqlx_query(&sql); + for param in ¶ms { + query = query.bind(param); + } + + // Query the db: + let mut stream = query.fetch(&self.pool); while let Some(row) = stream.try_next().await? { let table = row .try_get::<&str, &str>("table")