From a245f5575b99eb56a7b83f3612e482876e09e6a1 Mon Sep 17 00:00:00 2001 From: Enola Knezevic Date: Tue, 16 Jul 2024 17:48:32 +0200 Subject: [PATCH 01/18] DB except serialization --- Cargo.toml | 3 +- README.md | 5 ++- build.rs | 25 +++++++++++ resources/sql/SELECT_TABLES | 1 + src/config.rs | 10 +++++ src/db.rs | 86 +++++++++++++++++++++++++++++++++++++ src/errors.rs | 8 ++++ src/main.rs | 56 ++++++++++++++++++++++-- 8 files changed, 188 insertions(+), 6 deletions(-) create mode 100644 resources/sql/SELECT_TABLES create mode 100644 src/db.rs diff --git a/Cargo.toml b/Cargo.toml index 8d0e486..85c1200 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ license = "Apache-2.0" [dependencies] base64 = "0.22.1" http = "0.2" -reqwest = { version = "0.11", default_features = false, features = ["json", "default-tls"] } +reqwest = { version = "0.11", default-features = false, features = ["json", "default-tls"] } serde = { version = "1.0.152", features = ["serde_derive"] } serde_json = "1.0" thiserror = "1.0.38" @@ -21,6 +21,7 @@ laplace_rs = {git = "https://github.com/samply/laplace-rs.git", tag = "v0.3.0" } uuid = "1.8.0" rand = { default-features = false, version = "0.8.5" } futures-util = { version = "0.3", default-features = false, features = ["std"] } +sqlx = { version = "0.7", features = [ "runtime-tokio", "postgres", "macros"] } # Logging tracing = { version = "0.1.37", default_features = false } diff --git a/README.md b/README.md index 2de2579..fc21c93 100644 --- a/README.md +++ b/README.md @@ -49,9 +49,10 @@ EPSILON = "0.1" # Privacy budget parameter for obfuscating the counts in the str ROUNDING_STEP = "10" # The granularity of the rounding of the obfuscated values, has no effect if OBFUSCATE = "no", default value: 10 PROJECTS_NO_OBFUSCATION = "exliquid;dktk_supervisors;exporter;ehds2" # Projects for which the results are not to be obfuscated, separated by ;, default value: "exliquid;dktk_supervisors;exporter;ehds2" QUERIES_TO_CACHE = "queries_to_cache.conf" # The path to a file containing base64 encoded queries whose results are to be cached. If not set, no results are cached -PROVIDER = "name" #OMOP provider name -PROVIDER_ICON = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABAQMAAAAl21bKAAAAA1BMVEUAAACnej3aAAAAAXRSTlMAQObYZgAAAApJREFUCNdjYAAAAAIAAeIhvDMAAAAASUVORK5CYII=" # Base64 encoded OMOP provider icon +PROVIDER = "name" #EUCAIM provider name +PROVIDER_ICON = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABAQMAAAAl21bKAAAAA1BMVEUAAACnej3aAAAAAXRSTlMAQObYZgAAAApJREFUCNdjYAAAAAIAAeIhvDMAAAAASUVORK5CYII=" # Base64 encoded EUCAIM provider icon AUTH_HEADER = "ApiKey XXXX" #Authorization header +DB_CONNECTION_STRING = "postgresql://postgres:Test.123@localhost:5432/postgres" # Database connection string ``` Obfuscating zero counts is by default switched off. To enable obfuscating zero counts, set the env. variable `OBFUSCATE_ZERO`. 
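The `QUERIES_TO_CACHE` file and the Beam task bodies used throughout this series carry base64-encoded payloads. As a point of reference, here is a minimal sketch of producing such an encoded body with the `base64` crate already in the dependency tree; the JSON payload is the `SELECT_TABLES` example that later patches add to the README, and the snippet is an illustration rather than project tooling:

```rust
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};

fn main() {
    // JSON body naming one of the allowlisted SQL files under resources/sql.
    let payload = r#"{"payload":"SELECT_TABLES"}"#;
    let encoded = BASE64.encode(payload);
    // This is exactly the body string used in the sample SQL task below.
    assert_eq!(encoded, "eyJwYXlsb2FkIjoiU0VMRUNUX1RBQkxFUyJ9");
    println!("{encoded}");
}
```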
diff --git a/build.rs b/build.rs index 76667b0..b3a11b1 100644 --- a/build.rs +++ b/build.rs @@ -41,6 +41,30 @@ fn build_cqlmap() { ).unwrap(); } +fn build_sqlmap() { + let path = Path::new(&env::var("OUT_DIR").unwrap()).join("sql_replace_map.rs"); + let mut file = BufWriter::new(File::create(path).unwrap()); + + write!(&mut file, r#" + static SQL_REPLACE_MAP: once_cell::sync::Lazy<HashMap<&'static str, &'static str>> = once_cell::sync::Lazy::new(|| {{ + let mut map = HashMap::new(); + "#).unwrap(); + + for sqlfile in std::fs::read_dir(Path::new("resources/sql")).unwrap() { + let sqlfile = sqlfile.unwrap(); + let sqlfilename = sqlfile.file_name().to_str().unwrap().to_owned(); + let sqlcontent = std::fs::read_to_string(sqlfile.path()).unwrap(); + write!(&mut file, r####" + map.insert(r###"{sqlfilename}"###, r###"{sqlcontent}"###); + "####).unwrap(); + } + + writeln!(&mut file, " + map + }});" + ).unwrap(); +} + fn main() { build_data::set_GIT_COMMIT_SHORT(); build_data::set_GIT_DIRTY(); @@ -51,4 +75,5 @@ println!("cargo:rustc-env=SAMPLY_USER_AGENT=Samply.Focus.{}/{}", env!("CARGO_PKG_NAME"), version()); build_cqlmap(); + build_sqlmap(); } diff --git a/resources/sql/SELECT_TABLES b/resources/sql/SELECT_TABLES new file mode 100644 index 0000000..c59f3b3 --- /dev/null +++ b/resources/sql/SELECT_TABLES @@ -0,0 +1 @@ +SELECT * FROM pg_catalog.pg_tables \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index e736fba..4a69eeb 100644 --- a/src/config.rs +++ b/src/config.rs @@ -21,6 +21,8 @@ pub enum Obfuscate { pub enum EndpointType { Blaze, Omop, + BlazeAndSql, + Sql, } impl fmt::Display for EndpointType { @@ -28,6 +30,8 @@ impl fmt::Display for EndpointType { match self { EndpointType::Blaze => write!(f, "blaze"), EndpointType::Omop => write!(f, "omop"), + EndpointType::BlazeAndSql => write!(f, "blaze_sql"), + EndpointType::Sql => write!(f, "sql"), } } } @@ -151,6 +155,10 @@ struct CliArgs { #[clap(long, env, value_parser)] auth_header: Option<String>, + /// Database connection string + #[clap(long, env, value_parser)] + db_connection_string: Option<String>, + } pub(crate) struct Config { @@ -178,6 +186,7 @@ pub provider: Option<String>, pub provider_icon: Option<String>, pub auth_header: Option<String>, + pub db_connection_string: Option<String>, } impl Config { @@ -219,6 +228,7 @@ impl Config { provider: cli_args.provider, provider_icon: cli_args.provider_icon, auth_header: cli_args.auth_header, + db_connection_string: cli_args.db_connection_string, client, }; Ok(config) diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..4e77b6b --- /dev/null +++ b/src/db.rs @@ -0,0 +1,86 @@ +use sqlx::{postgres::PgPoolOptions, PgPool, postgres::PgRow}; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use std::collections::HashMap; +use tracing::{error, info}; +use crate::errors::FocusError; +use crate::util; + +#[derive(Serialize, Deserialize, Debug, Default, Clone)] +pub struct SqlQuery { + pub payload: String, +} + +include!(concat!(env!("OUT_DIR"), "/sql_replace_map.rs")); + +pub async fn get_pg_connection_pool(pg_url: &str, num_attempts: u32) -> Result<PgPool, FocusError> { + info!("Trying to establish a PostgreSQL connection pool"); + + let mut attempts = 0; + let mut err: Option<FocusError> = None; + + while attempts < num_attempts { + info!("Attempt to connect to PostgreSQL {} of {}", attempts + 1, num_attempts); + match PgPoolOptions::new() + .max_connections(10) + .connect(&pg_url) + .await + { + Ok(pg_con_pool) => { + info!("PostgreSQL connection successfull"); + return Ok(pg_con_pool) + }, + Err(e) => {
error!("Failed to connect to PostgreSQL. Attempt {} of {}: {}", attempts + 1, num_attempts, e); err = Some(FocusError::CannotConnectToDatabase(e.to_string())); } } attempts += 1; tokio::time::sleep(std::time::Duration::from_secs(1)).await; } Err(err.unwrap_or_else(|| FocusError::CannotConnectToDatabase("Failed to connect to PostgreSQL".into()))) } pub async fn healthcheck(pool: &PgPool) -> bool { let res = sqlx::query(include_str!("../resources/sql/SELECT_TABLES")) .fetch_all(pool) .await; if let Ok(_) = res {true} else {false} } pub async fn run_query(pool: &PgPool, query: &str) -> Result<Vec<PgRow>, FocusError> { sqlx::query(query) .fetch_all(pool) .await.map_err( FocusError::ErrorExecutingQuery) } pub async fn process_sql_task(pool: &PgPool, encoded: &str) -> Result<Vec<PgRow>, FocusError>{ let decoded = util::base64_decode(encoded)?; let key = String::from_utf8(decoded).map_err(FocusError::ErrorConvertingToString)?; let key = key.as_str(); let sql_query = SQL_REPLACE_MAP.get(&(key.clone())); if sql_query.is_none(){ return Err(FocusError::QueryNotAllowed(key.into())); } let query = sql_query.unwrap(); run_query(pool, query).await } #[cfg(test)] mod test { use super::*; #[tokio::test] #[ignore] //TODO mock DB async fn connect() { let pool = get_pg_connection_pool("postgresql://postgres:secret@localhost:5432/postgres", 1).await.unwrap(); assert!(healthcheck(&pool).await); } } diff --git a/src/errors.rs b/src/errors.rs index 49acc63..4f430b7 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -60,6 +60,14 @@ pub enum FocusError { MissingExporterEndpoint, #[error("Missing Exporter Task Type")] MissingExporterTaskType, + #[error("Cannot connect to database: {0}")] + CannotConnectToDatabase(String), + #[error("Error executing query: {0}")] + ErrorExecutingQuery(sqlx::Error), + #[error("Error converting to string: {0}")] + ErrorConvertingToString(std::string::FromUtf8Error), + #[error("Query not allowed: {0}")] + QueryNotAllowed(String), } impl FocusError { diff --git a/src/main.rs b/src/main.rs index 1e4e423..ec9ca58 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,6 +13,7 @@ mod task_processing; mod util; mod projects; mod exporter; +mod db; use base64::engine::general_purpose; @@ -21,13 +22,16 @@ use beam_lib::{TaskRequest, TaskResult}; use futures_util::future::BoxFuture; use futures_util::FutureExt; use laplace_rs::ObfCache; +use sqlx::PgPool; use tokio::sync::Mutex; + use crate::blaze::{parse_blaze_query_payload_ast, AstQuery}; use crate::config::EndpointType; use crate::util::{base64_decode, is_cql_tampered_with, obfuscate_counts_mr}; use crate::{config::CONFIG, errors::FocusError}; use blaze::CqlQuery; +use db::SqlQuery; use std::collections::HashMap; use std::ops::DerefMut; @@ -52,7 +56,7 @@ type BeamResult = TaskResult<String>; #[serde(tag = "lang", rename_all = "lowercase")] enum Language { Cql(CqlQuery), - Ast(AstQuery) + Ast(AstQuery), } #[derive(Debug, Deserialize, Serialize, Clone)] @@ -116,9 +120,22 @@ pub async fn main() -> ExitCode { } async fn main_loop() -> ExitCode { + let db_pool = if let Some(connection_string) = CONFIG.db_connection_string.clone() { + match db::get_pg_connection_pool(&connection_string, 8).await { + Err(e) => { + error!("Error connecting to database: {}, {}", connection_string, e); + return ExitCode::from(8); + }, + Ok(pool) => Some(pool), + } + } else { + None + }; let endpoint_service_available: fn() -> BoxFuture<'static, bool> = match CONFIG.endpoint_type { EndpointType::Blaze => ||
blaze::check_availability().boxed(), EndpointType::Omop => || async { true }.boxed(), // TODO health check + EndpointType::BlazeAndSql => || blaze::check_availability().boxed(), //TODO SQL health check + EndpointType::Sql => || async { true }.boxed(), // TODO health check }; let mut failures = 0; while !(beam::check_availability().await && endpoint_service_available().await) { @@ -144,12 +161,13 @@ async fn main_loop() -> ExitCode { task_processing::process_tasks(move |task| { let obf_cache = obf_cache.clone(); let report_cache = report_cache.clone(); - process_task(task, obf_cache, report_cache).boxed_local() + process_task(db_pool.clone(), task, obf_cache, report_cache).boxed_local() }).await; ExitCode::FAILURE } async fn process_task( + db_pool: Option, task: &BeamTask, obf_cache: Arc>, report_cache: Arc>, @@ -189,6 +207,37 @@ async fn process_task( }; run_cql_query(task, &query, obf_cache, report_cache, metadata.project, generated_from_ast).await + } else if CONFIG.endpoint_type == EndpointType::BlazeAndSql { + let mut generated_from_ast: bool = false; + let data = base64_decode(&task.body)?; + let query_maybe: Result = serde_json::from_slice(&(data.clone())); + if let Ok(sql_query) = query_maybe { + if let Some(pool) = db_pool{ + let result = db::process_sql_task(&pool, &(sql_query.payload)).await; + if let Ok(rows) = result { + + Ok(beam::beam_result::succeeded( + CONFIG.beam_app_id_long.clone(), + vec![task.clone().from], + task.id, + "".into(), + )) + } else {return Err(FocusError::CannotConnectToDatabase("SQL task but no connection String in config".into()));} + } + else { + return Err(FocusError::CannotConnectToDatabase("SQL task but no connection String in config".into())); + } + } else { + + let query: CqlQuery = match serde_json::from_slice::(&data)? { + Language::Cql(cql_query) => cql_query, + Language::Ast(ast_query) => { + generated_from_ast = true; + serde_json::from_str(&cql::generate_body(parse_blaze_query_payload_ast(&ast_query.payload)?)?)? 
+ } + }; + run_cql_query(task, &query, obf_cache, report_cache, metadata.project, generated_from_ast).await + } } else if CONFIG.endpoint_type == EndpointType::Omop { let decoded = util::base64_decode(&task.body)?; let intermediate_rep_query: intermediate_rep::IntermediateRepQuery = @@ -437,4 +486,5 @@ mod test { assert_eq!(metadata.task_type, Some(exporter::TaskType::Execute)); } -} \ No newline at end of file +} + From 4400fa1a0cc1b39cf25c9c1e8dffcab95afe925e Mon Sep 17 00:00:00 2001 From: Enola Knezevic Date: Wed, 17 Jul 2024 19:25:12 +0200 Subject: [PATCH 02/18] serialization of DB result --- Cargo.toml | 3 +- README.md | 31 ++++++----- src/config.rs | 26 ++++----- src/db.rs | 102 ++++++++++++++++++++++++----------- src/errors.rs | 2 + src/main.rs | 143 +++++++++++++++++++++++++++++++++++--------------- 6 files changed, 209 insertions(+), 98 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 85c1200..dd59e6a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,8 @@ laplace_rs = {git = "https://github.com/samply/laplace-rs.git", tag = "v0.3.0" } uuid = "1.8.0" rand = { default-features = false, version = "0.8.5" } futures-util = { version = "0.3", default-features = false, features = ["std"] } -sqlx = { version = "0.7", features = [ "runtime-tokio", "postgres", "macros"] } +sqlx = { version = "0.7.4", features = [ "runtime-tokio", "postgres", "macros", "chrono"] } +sqlx-pgrow-serde = "0.2.0" # Logging tracing = { version = "0.1.37", default_features = false } diff --git a/README.md b/README.md index e35f1f7..e004c39 100644 --- a/README.md +++ b/README.md @@ -34,20 +34,20 @@ BEAM_APP_ID_LONG = "app1.broker.example.com" ### Optional variables ```bash -RETRY_COUNT = "32" # The maximum number of retries for beam and blaze healthchecks, default value: 32 -ENDPOINT_TYPE = "blaze" # Type of the endpoint, allowed values: "blaze", "omop", default value: "blaze" +RETRY_COUNT = "32" # The maximum number of retries for beam and blaze healthchecks; default value: 32 +ENDPOINT_TYPE = "blaze" # Type of the endpoint, allowed values: "blaze", "omop", "sql", "blaze-and-sql"; default value: "blaze" EXPORTER_URL = " https://exporter.site/" # The exporter URL -OBFUSCATE = "yes" # Should the results be obfuscated - the "master switch", allowed values: "yes", "no", default value: "yes" -OBFUSCATE_BELOW_10_MODE = "1" # The mode of obfuscating values below 10: 0 - return zero, 1 - return ten, 2 - obfuscate using Laplace distribution and rounding, has no effect if OBFUSCATE = "no", default value: 1 -DELTA_PATIENT = "1." # Sensitivity parameter for obfuscating the counts in the Patient stratifier, has no effect if OBFUSCATE = "no", default value: 1 -DELTA_SPECIMEN = "20." # Sensitivity parameter for obfuscating the counts in the Specimen stratifier, has no effect if OBFUSCATE = "no", default value: 20 -DELTA_DIAGNOSIS = "3." # Sensitivity parameter for obfuscating the counts in the Diagnosis stratifier, has no effect if OBFUSCATE = "no", default value: 3 -DELTA_PROCEDURES = "1.7" # Sensitivity parameter for obfuscating the counts in the Procedures stratifier, has no effect if OBFUSCATE = "no", default value: 1.7 -DELTA_MEDICATION_STATEMENTS = "2.1" # Sensitivity parameter for obfuscating the counts in the Medication Statements stratifier, has no effect if OBFUSCATE = "no", default value: 2.1 -DELTA_HISTO = "20." 
# Sensitivity parameter for obfuscating the counts in the Histo stratifier, has no effect if OBFUSCATE = "no", default value: 20 -EPSILON = "0.1" # Privacy budget parameter for obfuscating the counts in the stratifiers, has no effect if OBFUSCATE = "no", default value: 0.1 -ROUNDING_STEP = "10" # The granularity of the rounding of the obfuscated values, has no effect if OBFUSCATE = "no", default value: 10 -PROJECTS_NO_OBFUSCATION = "exliquid;dktk_supervisors;exporter;ehds2" # Projects for which the results are not to be obfuscated, separated by ;, default value: "exliquid;dktk_supervisors;exporter;ehds2" +OBFUSCATE = "yes" # Should the results be obfuscated - the "master switch", allowed values: "yes", "no"; default value: "yes" +OBFUSCATE_BELOW_10_MODE = "1" # The mode of obfuscating values below 10: 0 - return zero, 1 - return ten, 2 - obfuscate using Laplace distribution and rounding, has no effect if OBFUSCATE = "no"; default value: 1 +DELTA_PATIENT = "1." # Sensitivity parameter for obfuscating the counts in the Patient stratifier, has no effect if OBFUSCATE = "no"; default value: 1 +DELTA_SPECIMEN = "20." # Sensitivity parameter for obfuscating the counts in the Specimen stratifier, has no effect if OBFUSCATE = "no"; default value: 20 +DELTA_DIAGNOSIS = "3." # Sensitivity parameter for obfuscating the counts in the Diagnosis stratifier, has no effect if OBFUSCATE = "no"; default value: 3 +DELTA_PROCEDURES = "1.7" # Sensitivity parameter for obfuscating the counts in the Procedures stratifier, has no effect if OBFUSCATE = "no"; default value: 1.7 +DELTA_MEDICATION_STATEMENTS = "2.1" # Sensitivity parameter for obfuscating the counts in the Medication Statements stratifier, has no effect if OBFUSCATE = "no"; default value: 2.1 +DELTA_HISTO = "20." # Sensitivity parameter for obfuscating the counts in the Histo stratifier, has no effect if OBFUSCATE = "no"; default value: 20 +EPSILON = "0.1" # Privacy budget parameter for obfuscating the counts in the stratifiers, has no effect if OBFUSCATE = "no"; default value: 0.1 +ROUNDING_STEP = "10" # The granularity of the rounding of the obfuscated values, has no effect if OBFUSCATE = "no"; default value: 10 +PROJECTS_NO_OBFUSCATION = "exliquid;dktk_supervisors;exporter;ehds2" # Projects for which the results are not to be obfuscated, separated by ";" ; default value: "exliquid;dktk_supervisors;exporter;ehds2" QUERIES_TO_CACHE = "queries_to_cache.conf" # The path to a file containing base64 encoded queries whose results are to be cached. 
If not set, no results are cached PROVIDER = "name" #EUCAIM provider name PROVIDER_ICON = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABAQMAAAAl21bKAAAAA1BMVEUAAACnej3aAAAAAXRSTlMAQObYZgAAAApJREFUCNdjYAAAAAIAAeIhvDMAAAAASUVORK5CYII=" # Base64 encoded EUCAIM provider icon @@ -81,6 +81,11 @@ Creating a sample task containing an abstract syntax tree (AST) query using curl curl -v -X POST -H "Content-Type: application/json" --data '{"id":"7fffefff-ffef-fcff-feef-feffffffffff","from":"app1.proxy1.broker","to":["app1.proxy1.broker"],"ttl":"10s","failure_strategy":{"retry":{"backoff_millisecs":1000,"max_tries":5}},"metadata":{"project":"bbmri"},"body":"eyJsYW5nIjoiYXN0IiwicGF5bG9hZCI6ImV5SmhjM1FpT25zaWIzQmxjbUZ1WkNJNklrOVNJaXdpWTJocGJHUnlaVzRpT2x0N0ltOXdaWEpoYm1RaU9pSkJUa1FpTENKamFHbHNaSEpsYmlJNlczc2liM0JsY21GdVpDSTZJazlTSWl3aVkyaHBiR1J5Wlc0aU9sdDdJbXRsZVNJNkltZGxibVJsY2lJc0luUjVjR1VpT2lKRlVWVkJURk1pTENKemVYTjBaVzBpT2lJaUxDSjJZV3gxWlNJNkltMWhiR1VpZlN4N0ltdGxlU0k2SW1kbGJtUmxjaUlzSW5SNWNHVWlPaUpGVVZWQlRGTWlMQ0p6ZVhOMFpXMGlPaUlpTENKMllXeDFaU0k2SW1abGJXRnNaU0o5WFgxZGZWMTlMQ0pwWkNJNkltRTJaakZqWTJZekxXVmlaakV0TkRJMFppMDVaRFk1TFRSbE5XUXhNelZtTWpNME1DSjkifQ=="}' -H "Authorization: ApiKey app1.proxy1.broker App1Secret" http://localhost:8081/v1/tasks ``` +Creating a sample SQL task using curl: +```bash + curl -v -X POST -H "Content-Type: application/json" --data '{"id":"7fffefff-ffef-fcff-feef-feffffffffff","from":"app1.proxy1.broker","to":["app1.proxy1.broker"],"ttl":"10s","failure_strategy":{"retry":{"backoff_millisecs":1000,"max_tries":5}},"metadata":{"project":"exliquid"},"body":"eyJwYXlsb2FkIjoiU0VMRUNUX1RBQkxFUyJ9"}' -H "Authorization: ApiKey app1.proxy1.broker App1Secret" http://localhost:8081/v1/tasks + ``` + Creating a sample [Exporter](https://github.com/samply/exporter) "execute" task containing an Exporter query using curl: ```bash diff --git a/src/config.rs b/src/config.rs index 4a69eeb..ce28ab9 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,5 +1,5 @@ -use std::path::PathBuf; use std::fmt; +use std::path::PathBuf; use beam_lib::AppId; use clap::Parser; @@ -10,7 +10,6 @@ use tracing::{debug, info, warn}; use crate::errors::FocusError; - #[derive(clap::ValueEnum, Clone, PartialEq, Debug)] pub enum Obfuscate { No, @@ -28,15 +27,14 @@ pub enum EndpointType { impl fmt::Display for EndpointType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - EndpointType::Blaze => write!(f, "blaze"), + EndpointType::Blaze => write!(f, "blaze"), EndpointType::Omop => write!(f, "omop"), - EndpointType::BlazeAndSql => write!(f, "blaze_sql"), + EndpointType::BlazeAndSql => write!(f, "blaze_and_sql"), EndpointType::Sql => write!(f, "sql"), } } } - pub(crate) static CONFIG: Lazy = Lazy::new(|| { debug!("Loading config"); Config::load().unwrap_or_else(|e| { @@ -132,7 +130,12 @@ struct CliArgs { rounding_step: usize, /// Projects for which the results are not to be obfuscated, separated by ; - #[clap(long, env, value_parser, default_value = "exliquid;dktk_supervisors;exporter;ehds2")] + #[clap( + long, + env, + value_parser, + default_value = "exliquid;dktk_supervisors;exporter;ehds2" + )] projects_no_obfuscation: String, /// Path to a file containing BASE64 encoded queries whose results are to be cached @@ -146,7 +149,7 @@ struct CliArgs { /// OMOP provider name #[clap(long, env, value_parser)] provider: Option, - + /// Base64 encoded OMOP provider icon #[clap(long, env, value_parser)] provider_icon: Option, @@ -155,10 +158,9 @@ struct CliArgs { #[clap(long, env, value_parser)] auth_header: Option, - 
/// Database connection string - #[clap(long, env, value_parser)] - db_connection_string: Option, - + /// Database connection string + #[clap(long, env, value_parser)] + db_connection_string: Option, } pub(crate) struct Config { @@ -284,7 +286,7 @@ pub fn prepare_reqwest_client(certs: &Vec) -> Result proxies.push( Proxy::all(v) - .map_err( FocusError::InvalidProxyConfig)? + .map_err(FocusError::InvalidProxyConfig)? .no_proxy(no_proxy.clone()), ), _ => (), diff --git a/src/db.rs b/src/db.rs index 4e77b6b..c82f557 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,10 +1,10 @@ -use sqlx::{postgres::PgPoolOptions, PgPool, postgres::PgRow}; +use crate::errors::FocusError; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; +use sqlx::{postgres::PgPoolOptions, postgres::PgRow, PgPool}; +use sqlx_pgrow_serde::SerMapPgRow; use std::collections::HashMap; -use tracing::{error, info}; -use crate::errors::FocusError; -use crate::util; +use tracing::{error, info, debug}; #[derive(Serialize, Deserialize, Debug, Default, Clone)] pub struct SqlQuery { @@ -15,12 +15,16 @@ include!(concat!(env!("OUT_DIR"), "/sql_replace_map.rs")); pub async fn get_pg_connection_pool(pg_url: &str, num_attempts: u32) -> Result { info!("Trying to establish a PostgreSQL connection pool"); - + let mut attempts = 0; let mut err: Option = None; - + while attempts < num_attempts { - info!("Attempt to connect to PostgreSQL {} of {}", attempts + 1, num_attempts); + info!( + "Attempt to connect to PostgreSQL {} of {}", + attempts + 1, + num_attempts + ); match PgPoolOptions::new() .max_connections(10) .connect(&pg_url) @@ -28,59 +32,97 @@ pub async fn get_pg_connection_pool(pg_url: &str, num_attempts: u32) -> Result
{ Ok(pg_con_pool) =>
{ info!("PostgreSQL connection successfull"); - return Ok(pg_con_pool) - }, + return Ok(pg_con_pool); + } Err(e) => { - error!("Failed to connect to PostgreSQL. Attempt {} of {}: {}", attempts + 1, num_attempts, e); + error!( + "Failed to connect to PostgreSQL. Attempt {} of {}: {}", + attempts + 1, + num_attempts, + e + ); err = Some(FocusError::CannotConnectToDatabase(e.to_string())); } } attempts += 1; tokio::time::sleep(std::time::Duration::from_secs(1)).await; } - Err(err.unwrap_or_else(|| FocusError::CannotConnectToDatabase("Failed to connect to PostgreSQL".into()))) + Err(err.unwrap_or_else(|| { + FocusError::CannotConnectToDatabase("Failed to connect to PostgreSQL".into()) + })) } pub async fn healthcheck(pool: &PgPool) -> bool { - - let res = sqlx::query(include_str!("../resources/sql/SELECT_TABLES")) - .fetch_all(pool) - .await; - if let Ok(_) = res {true} else {false} + let res = run_query(pool, SQL_REPLACE_MAP.get("SELECT_TABLES").unwrap()).await; //this file exists, safe to unwrap + if let Ok(_) = res { + true + } else { + false + } } pub async fn run_query(pool: &PgPool, query: &str) -> Result, FocusError> { - sqlx::query(query) .fetch_all(pool) - .await.map_err( FocusError::ErrorExecutingQuery) + .await + .map_err(FocusError::ErrorExecutingQuery) } -pub async fn process_sql_task(pool: &PgPool, encoded: &str) -> Result, FocusError>{ - let decoded = util::base64_decode(encoded)?; - let key = String::from_utf8(decoded).map_err(FocusError::ErrorConvertingToString)?; - let key = key.as_str(); - let sql_query = SQL_REPLACE_MAP.get(&(key.clone())); - if sql_query.is_none(){ +pub async fn process_sql_task(pool: &PgPool, key: &str) -> Result, FocusError> { + debug!("Executing query with key = {}", &key); + let sql_query = SQL_REPLACE_MAP.get(&key); + if sql_query.is_none() { return Err(FocusError::QueryNotAllowed(key.into())); } - let query = sql_query.unwrap(); + let query = sql_query.unwrap(); + debug!("Executing query {}", &query); run_query(pool, query).await - } +pub fn serialize_rows(rows: Vec) -> Result { + let mut rows_json: Vec = vec![]; + + for row in rows { + let row = SerMapPgRow::from(row); + let row_json = serde_json::to_value(&row)?; + rows_json.push(row_json); + } + + Ok(json!(rows_json)) +} -#[cfg(test)] +#[cfg(test)] mod test { use super::*; #[tokio::test] #[ignore] //TODO mock DB async fn connect() { - let pool = get_pg_connection_pool("postgresql://postgres:secret@localhost:5432/postgres", 1).await.unwrap(); - + let pool = + get_pg_connection_pool("postgresql://postgres:secret@localhost:5432/postgres", 1) + .await + .unwrap(); + assert!(healthcheck(&pool).await); } -} + #[tokio::test] + #[ignore] //TODO mock DB + async fn serialize() { + let pool = + get_pg_connection_pool("postgresql://postgres:secret@localhost:5432/postgres", 1) + .await + .unwrap(); + + let rows = run_query(&pool, SQL_REPLACE_MAP.get("SELECT_TABLES").unwrap()) + .await + .unwrap(); + + let rows_json = serialize_rows(rows).unwrap(); + + assert!(rows_json.is_array()); + + assert_ne!(rows_json[0]["hasindexes"], Value::Null); + } +} diff --git a/src/errors.rs b/src/errors.rs index 4f430b7..f3b8a2b 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -64,6 +64,8 @@ pub enum FocusError { CannotConnectToDatabase(String), #[error("Error executing query: {0}")] ErrorExecutingQuery(sqlx::Error), + #[error("QueryResultBad: {0}")] + QueryResultBad(String), #[error("Error converting to string: {0}")] ErrorConvertingToString(std::string::FromUtf8Error), #[error("Query not allowed: {0}")] diff --git 
a/src/main.rs b/src/main.rs index 58f108a..7803615 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,13 +8,12 @@ mod errors; mod graceful_shutdown; mod logger; +mod db; +mod exporter; mod intermediate_rep; +mod projects; mod task_processing; mod util; -mod projects; -mod exporter; -mod db; - use base64::engine::general_purpose; use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; @@ -25,13 +24,11 @@ use laplace_rs::ObfCache; use sqlx::PgPool; use tokio::sync::Mutex; - use crate::blaze::{parse_blaze_query_payload_ast, AstQuery}; use crate::config::EndpointType; use crate::util::{base64_decode, is_cql_tampered_with, obfuscate_counts_mr}; use crate::{config::CONFIG, errors::FocusError}; use blaze::CqlQuery; -use db::SqlQuery; use std::collections::HashMap; use std::ops::DerefMut; @@ -42,7 +39,7 @@ use std::time::{SystemTime, UNIX_EPOCH}; use std::{process::exit, time::Duration}; use serde::{Deserialize, Serialize}; -use tracing::{debug, error, warn, trace}; +use tracing::{debug, error, trace, warn}; // result cache type SearchQuery = String; @@ -125,9 +122,9 @@ async fn main_loop() -> ExitCode { let db_pool = if let Some(connection_string) = CONFIG.db_connection_string.clone() { match db::get_pg_connection_pool(&connection_string, 8).await { Err(e) => { - error!("Error connecting to database: {}, {}", connection_string, e); + error!("Error connecting to database: {}", e); return ExitCode::from(8); - }, + } Ok(pool) => Some(pool), } } else { @@ -137,7 +134,7 @@ async fn main_loop() -> ExitCode { EndpointType::Blaze => || blaze::check_availability().boxed(), EndpointType::Omop => || async { true }.boxed(), // TODO health check EndpointType::BlazeAndSql => || blaze::check_availability().boxed(), //TODO SQL health check - EndpointType::Sql => || async { true }.boxed(), // TODO health check + EndpointType::Sql => || async { true }.boxed(), // TODO health check }; let mut failures = 0; while !(beam::check_availability().await && endpoint_service_available().await) { @@ -152,10 +149,9 @@ async fn main_loop() -> ExitCode { tokio::time::sleep(Duration::from_secs(2)).await; warn!( "Retrying connection (attempt {}/{})", - failures, - CONFIG.retry_count + failures, CONFIG.retry_count ); - }; + } let report_cache = Arc::new(Mutex::new(ReportCache::new())); let obf_cache = Arc::new(Mutex::new(ObfCache { cache: Default::default(), @@ -164,7 +160,8 @@ async fn main_loop() -> ExitCode { let obf_cache = obf_cache.clone(); let report_cache = report_cache.clone(); process_task(db_pool.clone(), task, obf_cache, report_cache).boxed_local() - }).await; + }) + .await; ExitCode::FAILURE } @@ -178,7 +175,7 @@ async fn process_task( let metadata: Metadata = serde_json::from_value(task.metadata.clone()).unwrap_or(Metadata { project: "default_obfuscation".to_string(), - task_type: None + task_type: None, }); if metadata.project == "focus-healthcheck" { @@ -186,7 +183,7 @@ async fn process_task( CONFIG.beam_app_id_long.clone(), vec![task.from.clone()], task.id, - "healthy".into() + "healthy".into(), )); } if metadata.project == "exporter" { @@ -204,41 +201,107 @@ async fn process_task( Language::Cql(cql_query) => cql_query, Language::Ast(ast_query) => { generated_from_ast = true; - serde_json::from_str(&cql::generate_body(parse_blaze_query_payload_ast(&ast_query.payload)?)?)? + serde_json::from_str(&cql::generate_body(parse_blaze_query_payload_ast( + &ast_query.payload, + )?)?)? 
} }; - run_cql_query(task, &query, obf_cache, report_cache, metadata.project, generated_from_ast).await - - } else if CONFIG.endpoint_type == EndpointType::BlazeAndSql { + run_cql_query( + task, + &query, + obf_cache, + report_cache, + metadata.project, + generated_from_ast, + ) + .await + } else if CONFIG.endpoint_type == EndpointType::BlazeAndSql { let mut generated_from_ast: bool = false; let data = base64_decode(&task.body)?; - let query_maybe: Result = serde_json::from_slice(&(data.clone())); + let query_maybe: Result = + serde_json::from_slice(&(data.clone())); if let Ok(sql_query) = query_maybe { - if let Some(pool) = db_pool{ + if let Some(pool) = db_pool { let result = db::process_sql_task(&pool, &(sql_query.payload)).await; if let Ok(rows) = result { - + let rows_json = db::serialize_rows(rows)?; + trace!("result: {}", &rows_json); + Ok(beam::beam_result::succeeded( CONFIG.beam_app_id_long.clone(), vec![task.clone().from], task.id, - "".into(), + BASE64.encode(rows_json.to_string()), )) - } else {return Err(FocusError::CannotConnectToDatabase("SQL task but no connection String in config".into()));} - } - else { - return Err(FocusError::CannotConnectToDatabase("SQL task but no connection String in config".into())); + } else { + let error = result.err().unwrap(); + error!("Error executing query: {}", error); + return Err(error); + } + } else { + return Err(FocusError::CannotConnectToDatabase( + "SQL task but no connection String in config".into(), + )); } } else { - let query: CqlQuery = match serde_json::from_slice::(&data)? { Language::Cql(cql_query) => cql_query, Language::Ast(ast_query) => { generated_from_ast = true; - serde_json::from_str(&cql::generate_body(parse_blaze_query_payload_ast(&ast_query.payload)?)?)? + serde_json::from_str(&cql::generate_body(parse_blaze_query_payload_ast( + &ast_query.payload, + )?)?)? 
} }; - run_cql_query(task, &query, obf_cache, report_cache, metadata.project, generated_from_ast).await + run_cql_query( + task, + &query, + obf_cache, + report_cache, + metadata.project, + generated_from_ast, + ) + .await + } + } else if CONFIG.endpoint_type == EndpointType::Sql { + let data = base64_decode(&task.body)?; + let query_maybe: Result = serde_json::from_slice(&(data)); + if let Ok(sql_query) = query_maybe { + if let Some(pool) = db_pool { + let result = db::process_sql_task(&pool, &(sql_query.payload)).await; + if let Ok(rows) = result { + let rows_json = db::serialize_rows(rows)?; + + Ok(beam::beam_result::succeeded( + CONFIG.beam_app_id_long.clone(), + vec![task.clone().from], + task.id, + BASE64.encode(rows_json.to_string()), + )) + } else { + return Err(FocusError::QueryResultBad( + "Query executed but result not readable".into(), + )); + } + } else { + return Err(FocusError::CannotConnectToDatabase( + "SQL task but no connection String in config".into(), + )); + } + } else { + warn!( + "Wrong type of query for an SQL only store: {}, {:?}", + CONFIG.endpoint_type, data + ); + Ok(beam::beam_result::perm_failed( + CONFIG.beam_app_id_long.clone(), + vec![task.from.clone()], + task.id, + format!( + "Wrong type of query for an SQL only store: {}, {:?}", + CONFIG.endpoint_type, data + ), + )) } } else if CONFIG.endpoint_type == EndpointType::Omop { let decoded = util::base64_decode(&task.body)?; @@ -248,8 +311,7 @@ async fn process_task( let query_decoded = general_purpose::STANDARD .decode(intermediate_rep_query.query) .map_err(FocusError::DecodeError)?; - let ast: ast::Ast = - serde_json::from_slice(&query_decoded)?; + let ast: ast::Ast = serde_json::from_slice(&query_decoded)?; Ok(run_intermediate_rep_query(task, ast).await?) } else { @@ -275,7 +337,7 @@ async fn run_cql_query( obf_cache: Arc>, report_cache: Arc>, project: String, - generated_from_ast: bool + generated_from_ast: bool, ) -> Result { let encoded_query = query.lib["content"][0]["data"] @@ -310,9 +372,8 @@ async fn run_cql_query( let cql_result_new = match report_from_cache { Some(some_report_from_cache) => some_report_from_cache.to_string(), None => { - let query = - if generated_from_ast { - query.clone() + let query = if generated_from_ast { + query.clone() } else { replace_cql_library(query.clone())? }; @@ -466,7 +527,6 @@ fn beam_result(task: BeamTask, measure_report: String) -> Result Date: Wed, 17 Jul 2024 19:36:51 +0200 Subject: [PATCH 03/18] style --- src/db.rs | 10 +++------- src/errors.rs | 2 -- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/db.rs b/src/db.rs index c82f557..1025597 100644 --- a/src/db.rs +++ b/src/db.rs @@ -27,7 +27,7 @@ pub async fn get_pg_connection_pool(pg_url: &str, num_attempts: u32) -> Result
<PgPool, FocusError>
{ @@ -54,11 +54,7 @@ pub async fn get_pg_connection_pool(pg_url: &str, num_attempts: u32) -> Result
<PgPool, FocusError> { pub async fn healthcheck(pool: &PgPool) ->
bool { let res = run_query(pool, SQL_REPLACE_MAP.get("SELECT_TABLES").unwrap()).await; //this file exists, safe to unwrap - if let Ok(_) = res { - true - } else { - false - } + res.is_ok() } pub async fn run_query(pool: &PgPool, query: &str) -> Result, FocusError> { @@ -98,7 +94,7 @@ mod test { #[tokio::test] #[ignore] //TODO mock DB - async fn connect() { + async fn connect_healthcheck() { let pool = get_pg_connection_pool("postgresql://postgres:secret@localhost:5432/postgres", 1) .await diff --git a/src/errors.rs b/src/errors.rs index f3b8a2b..d01059d 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -66,8 +66,6 @@ pub enum FocusError { ErrorExecutingQuery(sqlx::Error), #[error("QueryResultBad: {0}")] QueryResultBad(String), - #[error("Error converting to string: {0}")] - ErrorConvertingToString(std::string::FromUtf8Error), #[error("Query not allowed: {0}")] QueryNotAllowed(String), } From 7144502273f22d9511a8df049d522ef108d9c2b9 Mon Sep 17 00:00:00 2001 From: Enola Knezevic <115070135+enola-dkfz@users.noreply.github.com> Date: Thu, 18 Jul 2024 14:07:28 +0200 Subject: [PATCH 04/18] Update README.md Co-authored-by: Tobias Kussel --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e004c39..d8e56be 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ Creating a sample task containing an abstract syntax tree (AST) query using curl curl -v -X POST -H "Content-Type: application/json" --data '{"id":"7fffefff-ffef-fcff-feef-feffffffffff","from":"app1.proxy1.broker","to":["app1.proxy1.broker"],"ttl":"10s","failure_strategy":{"retry":{"backoff_millisecs":1000,"max_tries":5}},"metadata":{"project":"bbmri"},"body":"eyJsYW5nIjoiYXN0IiwicGF5bG9hZCI6ImV5SmhjM1FpT25zaWIzQmxjbUZ1WkNJNklrOVNJaXdpWTJocGJHUnlaVzRpT2x0N0ltOXdaWEpoYm1RaU9pSkJUa1FpTENKamFHbHNaSEpsYmlJNlczc2liM0JsY21GdVpDSTZJazlTSWl3aVkyaHBiR1J5Wlc0aU9sdDdJbXRsZVNJNkltZGxibVJsY2lJc0luUjVjR1VpT2lKRlVWVkJURk1pTENKemVYTjBaVzBpT2lJaUxDSjJZV3gxWlNJNkltMWhiR1VpZlN4N0ltdGxlU0k2SW1kbGJtUmxjaUlzSW5SNWNHVWlPaUpGVVZWQlRGTWlMQ0p6ZVhOMFpXMGlPaUlpTENKMllXeDFaU0k2SW1abGJXRnNaU0o5WFgxZGZWMTlMQ0pwWkNJNkltRTJaakZqWTJZekxXVmlaakV0TkRJMFppMDVaRFk1TFRSbE5XUXhNelZtTWpNME1DSjkifQ=="}' -H "Authorization: ApiKey app1.proxy1.broker App1Secret" http://localhost:8081/v1/tasks ``` -Creating a sample SQL task using curl: +Creating a sample SQL task for a `SELECT_TABLES` query using curl: ```bash curl -v -X POST -H "Content-Type: application/json" --data '{"id":"7fffefff-ffef-fcff-feef-feffffffffff","from":"app1.proxy1.broker","to":["app1.proxy1.broker"],"ttl":"10s","failure_strategy":{"retry":{"backoff_millisecs":1000,"max_tries":5}},"metadata":{"project":"exliquid"},"body":"eyJwYXlsb2FkIjoiU0VMRUNUX1RBQkxFUyJ9"}' -H "Authorization: ApiKey app1.proxy1.broker App1Secret" http://localhost:8081/v1/tasks ``` From 08ce472621ca28cd4db2a40f5b0eee079c1762dd Mon Sep 17 00:00:00 2001 From: Enola Knezevic <115070135+enola-dkfz@users.noreply.github.com> Date: Thu, 18 Jul 2024 14:53:10 +0200 Subject: [PATCH 05/18] Update src/db.rs Co-authored-by: Jan <59206115+Threated@users.noreply.github.com> --- src/db.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/db.rs b/src/db.rs index 1025597..56f875a 100644 --- a/src/db.rs +++ b/src/db.rs @@ -67,10 +67,9 @@ pub async fn run_query(pool: &PgPool, query: &str) -> Result, FocusEr pub async fn process_sql_task(pool: &PgPool, key: &str) -> Result, FocusError> { debug!("Executing query with key = {}", &key); let sql_query = SQL_REPLACE_MAP.get(&key); - if 
sql_query.is_none() { + let Some(query) = sql_query else { return Err(FocusError::QueryNotAllowed(key.into())); - } - let query = sql_query.unwrap(); + }; debug!("Executing query {}", &query); run_query(pool, query).await From a27293c4d78867d31e4845a811474fd0a6193ddc Mon Sep 17 00:00:00 2001 From: Enola Knezevic <115070135+enola-dkfz@users.noreply.github.com> Date: Thu, 18 Jul 2024 14:53:55 +0200 Subject: [PATCH 06/18] Update src/main.rs Co-authored-by: Jan <59206115+Threated@users.noreply.github.com> --- src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main.rs b/src/main.rs index 7803615..5a07d95 100644 --- a/src/main.rs +++ b/src/main.rs @@ -229,7 +229,7 @@ async fn process_task( Ok(beam::beam_result::succeeded( CONFIG.beam_app_id_long.clone(), - vec![task.clone().from], + vec![task.from.clone()], task.id, BASE64.encode(rows_json.to_string()), )) From 5c0ed700356a4d9a033af3dfc757ded8b664d989 Mon Sep 17 00:00:00 2001 From: Enola Knezevic <115070135+enola-dkfz@users.noreply.github.com> Date: Fri, 19 Jul 2024 15:38:55 +0200 Subject: [PATCH 07/18] Update src/db.rs Co-authored-by: Jan <59206115+Threated@users.noreply.github.com> --- src/db.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/db.rs b/src/db.rs index 56f875a..334eb69 100644 --- a/src/db.rs +++ b/src/db.rs @@ -76,7 +76,7 @@ pub async fn process_sql_task(pool: &PgPool, key: &str) -> Result, Fo } pub fn serialize_rows(rows: Vec) -> Result { - let mut rows_json: Vec = vec![]; + let mut rows_json: Vec = Vec::with_capacity(rows.len()); for row in rows { let row = SerMapPgRow::from(row); From 9fd8a9d105c3e04e7b4c9db25d4d90efc3ceece9 Mon Sep 17 00:00:00 2001 From: Enola Knezevic Date: Fri, 19 Jul 2024 15:39:52 +0200 Subject: [PATCH 08/18] requested changes --- src/db.rs | 12 +++++------- src/main.rs | 4 ++-- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/db.rs b/src/db.rs index 1025597..32a3f5b 100644 --- a/src/db.rs +++ b/src/db.rs @@ -4,7 +4,7 @@ use serde_json::{json, Value}; use sqlx::{postgres::PgPoolOptions, postgres::PgRow, PgPool}; use sqlx_pgrow_serde::SerMapPgRow; use std::collections::HashMap; -use tracing::{error, info, debug}; +use tracing::{warn, info, debug}; #[derive(Serialize, Deserialize, Debug, Default, Clone)] pub struct SqlQuery { @@ -35,7 +35,7 @@ pub async fn get_pg_connection_pool(pg_url: &str, num_attempts: u32) -> Result
<PgPool, FocusError> { Err(e) =>
{ - error!( + warn!( "Failed to connect to PostgreSQL. Attempt {} of {}: {}", attempts + 1, num_attempts, @@ -47,9 +47,7 @@ pub async fn get_pg_connection_pool(pg_url: &str, num_attempts: u32) -> Result
<PgPool, FocusError> { - Err(err.unwrap_or_else(|| { - FocusError::CannotConnectToDatabase("Failed to connect to PostgreSQL".into()) - })) + Err(err.unwrap()) } pub async fn healthcheck(pool: &PgPool) ->
bool { @@ -93,7 +91,7 @@ mod test { use super::*; #[tokio::test] - #[ignore] //TODO mock DB + //#[ignore] //TODO mock DB async fn connect_healthcheck() { let pool = get_pg_connection_pool("postgresql://postgres:secret@localhost:5432/postgres", 1) @@ -104,7 +102,7 @@ mod test { } #[tokio::test] - #[ignore] //TODO mock DB + //#[ignore] //TODO mock DB async fn serialize() { let pool = get_pg_connection_pool("postgresql://postgres:secret@localhost:5432/postgres", 1) diff --git a/src/main.rs b/src/main.rs index 7803615..214de68 100644 --- a/src/main.rs +++ b/src/main.rs @@ -231,7 +231,7 @@ async fn process_task( CONFIG.beam_app_id_long.clone(), vec![task.clone().from], task.id, - BASE64.encode(rows_json.to_string()), + BASE64.encode(serde_json::to_string(&rows_json)?), )) } else { let error = result.err().unwrap(); @@ -276,7 +276,7 @@ async fn process_task( CONFIG.beam_app_id_long.clone(), vec![task.clone().from], task.id, - BASE64.encode(rows_json.to_string()), + BASE64.encode(serde_json::to_string(&rows_json)?), )) } else { return Err(FocusError::QueryResultBad( From 410c5b1b7e2849af71d512c1a1785dd270e0b70c Mon Sep 17 00:00:00 2001 From: Enola Knezevic Date: Fri, 19 Jul 2024 15:47:17 +0200 Subject: [PATCH 09/18] ignoring tests --- src/db.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/db.rs b/src/db.rs index c5ec19d..843fea3 100644 --- a/src/db.rs +++ b/src/db.rs @@ -90,7 +90,7 @@ mod test { use super::*; #[tokio::test] - //#[ignore] //TODO mock DB + #[ignore] //TODO mock DB async fn connect_healthcheck() { let pool = get_pg_connection_pool("postgresql://postgres:secret@localhost:5432/postgres", 1) @@ -101,7 +101,7 @@ mod test { } #[tokio::test] - //#[ignore] //TODO mock DB + #[ignore] //TODO mock DB async fn serialize() { let pool = get_pg_connection_pool("postgresql://postgres:secret@localhost:5432/postgres", 1) From 69eb81b8afbff17e57b76731a150b9175f892538 Mon Sep 17 00:00:00 2001 From: Enola Knezevic Date: Fri, 19 Jul 2024 15:51:09 +0200 Subject: [PATCH 10/18] requested changes --- src/db.rs | 2 +- src/main.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/db.rs b/src/db.rs index 843fea3..2e474ec 100644 --- a/src/db.rs +++ b/src/db.rs @@ -82,7 +82,7 @@ pub fn serialize_rows(rows: Vec) -> Result { rows_json.push(row_json); } - Ok(json!(rows_json)) + Ok(Value::Array(rows_json)) } #[cfg(test)] diff --git a/src/main.rs b/src/main.rs index e6aee42..55b988a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -133,8 +133,8 @@ async fn main_loop() -> ExitCode { let endpoint_service_available: fn() -> BoxFuture<'static, bool> = match CONFIG.endpoint_type { EndpointType::Blaze => || blaze::check_availability().boxed(), EndpointType::Omop => || async { true }.boxed(), // TODO health check - EndpointType::BlazeAndSql => || blaze::check_availability().boxed(), //TODO SQL health check - EndpointType::Sql => || async { true }.boxed(), // TODO health check + EndpointType::BlazeAndSql => || blaze::check_availability().boxed(), + EndpointType::Sql => || async { true }.boxed(), }; let mut failures = 0; while !(beam::check_availability().await && endpoint_service_available().await) { From fe0cf673b9913810bd37401450bde11e4644ba58 Mon Sep 17 00:00:00 2001 From: Enola Knezevic Date: Fri, 19 Jul 2024 15:53:26 +0200 Subject: [PATCH 11/18] renamed db to postgres --- src/config.rs | 6 +++--- src/main.rs | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/config.rs b/src/config.rs index ce28ab9..1498943 100644 --- a/src/config.rs +++ 
b/src/config.rs @@ -160,7 +160,7 @@ struct CliArgs { /// Database connection string #[clap(long, env, value_parser)] - db_connection_string: Option, + postgres_connection_string: Option, } pub(crate) struct Config { @@ -188,7 +188,7 @@ pub(crate) struct Config { pub provider: Option, pub provider_icon: Option, pub auth_header: Option, - pub db_connection_string: Option, + pub postgres_connection_string: Option, } impl Config { @@ -230,7 +230,7 @@ impl Config { provider: cli_args.provider, provider_icon: cli_args.provider_icon, auth_header: cli_args.auth_header, - db_connection_string: cli_args.db_connection_string, + postgres_connection_string: cli_args.postgres_connection_string, client, }; Ok(config) diff --git a/src/main.rs b/src/main.rs index 55b988a..d99cb58 100644 --- a/src/main.rs +++ b/src/main.rs @@ -119,7 +119,7 @@ pub async fn main() -> ExitCode { } async fn main_loop() -> ExitCode { - let db_pool = if let Some(connection_string) = CONFIG.db_connection_string.clone() { + let db_pool = if let Some(connection_string) = CONFIG.postgres_connection_string.clone() { match db::get_pg_connection_pool(&connection_string, 8).await { Err(e) => { error!("Error connecting to database: {}", e); From 3ebdbc28a79dd45e335918abd21345c1d950c177 Mon Sep 17 00:00:00 2001 From: Enola Knezevic Date: Fri, 19 Jul 2024 15:54:27 +0200 Subject: [PATCH 12/18] renamed db to postgres in Readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d8e56be..1718f1a 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ QUERIES_TO_CACHE = "queries_to_cache.conf" # The path to a file containing base6 PROVIDER = "name" #EUCAIM provider name PROVIDER_ICON = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABAQMAAAAl21bKAAAAA1BMVEUAAACnej3aAAAAAXRSTlMAQObYZgAAAApJREFUCNdjYAAAAAIAAeIhvDMAAAAASUVORK5CYII=" # Base64 encoded EUCAIM provider icon AUTH_HEADER = "ApiKey XXXX" #Authorization header -DB_CONNECTION_STRING = "postgresql://postgres:Test.123@localhost:5432/postgres" # Database connection string +POSTGRES_CONNECTION_STRING = "postgresql://postgres:Test.123@localhost:5432/postgres" # Postgres connection string ``` Obfuscating zero counts is by default switched off. To enable obfuscating zero counts, set the env. variable `OBFUSCATE_ZERO`. 
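At this point in the series the SQL path is stable: a task body names an allowlisted SQL file, the rows come back as `Vec<PgRow>`, and `serialize_rows` turns them into a `serde_json::Value::Array` of objects keyed by column name via `SerMapPgRow`. A self-contained sketch of that serialization step, assuming sqlx 0.7 with the `runtime-tokio` feature, sqlx-pgrow-serde 0.2, and a tokio runtime with the `macros` feature; the connection string is the one from the tests and the query is the allowlisted `SELECT_TABLES` statement:

```rust
use serde_json::Value;
use sqlx::postgres::{PgPoolOptions, PgRow};
use sqlx_pgrow_serde::SerMapPgRow;

// Mirrors serialize_rows: each PgRow becomes a JSON object keyed by column name.
fn rows_to_json(rows: Vec<PgRow>) -> Result<Value, serde_json::Error> {
    let mut rows_json = Vec::with_capacity(rows.len());
    for row in rows {
        rows_json.push(serde_json::to_value(&SerMapPgRow::from(row))?);
    }
    Ok(Value::Array(rows_json))
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let pool = PgPoolOptions::new()
        .max_connections(1)
        .connect("postgresql://postgres:secret@localhost:5432/postgres")
        .await?;
    let rows = sqlx::query("SELECT * FROM pg_catalog.pg_tables")
        .fetch_all(&pool)
        .await?;
    // Prints a JSON array with one object per table row, e.g. {"hasindexes": true, ...}.
    println!("{}", rows_to_json(rows)?);
    Ok(())
}
```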
From c71de6d6327be045de2da93b25a3d5ffebcf4d33 Mon Sep 17 00:00:00 2001 From: Enola Knezevic Date: Fri, 19 Jul 2024 16:06:27 +0200 Subject: [PATCH 13/18] up the error --- src/main.rs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/main.rs b/src/main.rs index d99cb58..3cd2c57 100644 --- a/src/main.rs +++ b/src/main.rs @@ -222,8 +222,7 @@ async fn process_task( serde_json::from_slice(&(data.clone())); if let Ok(sql_query) = query_maybe { if let Some(pool) = db_pool { - let result = db::process_sql_task(&pool, &(sql_query.payload)).await; - if let Ok(rows) = result { + let rows = db::process_sql_task(&pool, &(sql_query.payload)).await?; let rows_json = db::serialize_rows(rows)?; trace!("result: {}", &rows_json); @@ -233,11 +232,6 @@ async fn process_task( task.id, BASE64.encode(serde_json::to_string(&rows_json)?), )) - } else { - let error = result.err().unwrap(); - error!("Error executing query: {}", error); - return Err(error); - } } else { return Err(FocusError::CannotConnectToDatabase( "SQL task but no connection String in config".into(), From 3dce3bf91e292dadfe5652e217a7cfa0002f112a Mon Sep 17 00:00:00 2001 From: Enola Knezevic Date: Fri, 19 Jul 2024 16:24:15 +0200 Subject: [PATCH 14/18] match endpoint type --- src/main.rs | 197 +++++++++++++++++++++++++--------------------------- 1 file changed, 94 insertions(+), 103 deletions(-) diff --git a/src/main.rs b/src/main.rs index 3cd2c57..f27e9f0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -194,50 +194,10 @@ async fn process_task( return run_exporter_query(task, body, task_type).await; } - if CONFIG.endpoint_type == EndpointType::Blaze { - let mut generated_from_ast: bool = false; - let data = base64_decode(&task.body)?; - let query: CqlQuery = match serde_json::from_slice::(&data)? { - Language::Cql(cql_query) => cql_query, - Language::Ast(ast_query) => { - generated_from_ast = true; - serde_json::from_str(&cql::generate_body(parse_blaze_query_payload_ast( - &ast_query.payload, - )?)?)? - } - }; - run_cql_query( - task, - &query, - obf_cache, - report_cache, - metadata.project, - generated_from_ast, - ) - .await - } else if CONFIG.endpoint_type == EndpointType::BlazeAndSql { - let mut generated_from_ast: bool = false; - let data = base64_decode(&task.body)?; - let query_maybe: Result = - serde_json::from_slice(&(data.clone())); - if let Ok(sql_query) = query_maybe { - if let Some(pool) = db_pool { - let rows = db::process_sql_task(&pool, &(sql_query.payload)).await?; - let rows_json = db::serialize_rows(rows)?; - trace!("result: {}", &rows_json); - - Ok(beam::beam_result::succeeded( - CONFIG.beam_app_id_long.clone(), - vec![task.from.clone()], - task.id, - BASE64.encode(serde_json::to_string(&rows_json)?), - )) - } else { - return Err(FocusError::CannotConnectToDatabase( - "SQL task but no connection String in config".into(), - )); - } - } else { + match CONFIG.endpoint_type { + EndpointType::Blaze => { + let mut generated_from_ast: bool = false; + let data = base64_decode(&task.body)?; let query: CqlQuery = match serde_json::from_slice::(&data)? 
{ Language::Cql(cql_query) => cql_query, Language::Ast(ast_query) => { @@ -256,72 +216,103 @@ async fn process_task( generated_from_ast, ) .await - } - } else if CONFIG.endpoint_type == EndpointType::Sql { - let data = base64_decode(&task.body)?; - let query_maybe: Result = serde_json::from_slice(&(data)); - if let Ok(sql_query) = query_maybe { - if let Some(pool) = db_pool { - let result = db::process_sql_task(&pool, &(sql_query.payload)).await; - if let Ok(rows) = result { - let rows_json = db::serialize_rows(rows)?; - - Ok(beam::beam_result::succeeded( - CONFIG.beam_app_id_long.clone(), - vec![task.clone().from], - task.id, - BASE64.encode(serde_json::to_string(&rows_json)?), - )) + }, + EndpointType::BlazeAndSql => { + let mut generated_from_ast: bool = false; + let data = base64_decode(&task.body)?; + let query_maybe: Result = + serde_json::from_slice(&(data.clone())); + if let Ok(sql_query) = query_maybe { + if let Some(pool) = db_pool { + let rows = db::process_sql_task(&pool, &(sql_query.payload)).await?; + let rows_json = db::serialize_rows(rows)?; + trace!("result: {}", &rows_json); + + Ok(beam::beam_result::succeeded( + CONFIG.beam_app_id_long.clone(), + vec![task.from.clone()], + task.id, + BASE64.encode(serde_json::to_string(&rows_json)?), + )) } else { - return Err(FocusError::QueryResultBad( - "Query executed but result not readable".into(), + return Err(FocusError::CannotConnectToDatabase( + "SQL task but no connection String in config".into(), )); } } else { - return Err(FocusError::CannotConnectToDatabase( - "SQL task but no connection String in config".into(), - )); + let query: CqlQuery = match serde_json::from_slice::(&data)? { + Language::Cql(cql_query) => cql_query, + Language::Ast(ast_query) => { + generated_from_ast = true; + serde_json::from_str(&cql::generate_body(parse_blaze_query_payload_ast( + &ast_query.payload, + )?)?)? 
+ } + }; + run_cql_query( + task, + &query, + obf_cache, + report_cache, + metadata.project, + generated_from_ast, + ) + .await } - } else { - warn!( - "Wrong type of query for an SQL only store: {}, {:?}", - CONFIG.endpoint_type, data - ); - Ok(beam::beam_result::perm_failed( - CONFIG.beam_app_id_long.clone(), - vec![task.from.clone()], - task.id, - format!( + }, + EndpointType::Sql => { + let data = base64_decode(&task.body)?; + let query_maybe: Result = serde_json::from_slice(&(data)); + if let Ok(sql_query) = query_maybe { + if let Some(pool) = db_pool { + let result = db::process_sql_task(&pool, &(sql_query.payload)).await; + if let Ok(rows) = result { + let rows_json = db::serialize_rows(rows)?; + + Ok(beam::beam_result::succeeded( + CONFIG.beam_app_id_long.clone(), + vec![task.clone().from], + task.id, + BASE64.encode(serde_json::to_string(&rows_json)?), + )) + } else { + return Err(FocusError::QueryResultBad( + "Query executed but result not readable".into(), + )); + } + } else { + return Err(FocusError::CannotConnectToDatabase( + "SQL task but no connection String in config".into(), + )); + } + } else { + warn!( "Wrong type of query for an SQL only store: {}, {:?}", CONFIG.endpoint_type, data - ), - )) - } - } else if CONFIG.endpoint_type == EndpointType::Omop { - let decoded = util::base64_decode(&task.body)?; - let intermediate_rep_query: intermediate_rep::IntermediateRepQuery = - serde_json::from_slice(&decoded)?; - //TODO check that the language is ast - let query_decoded = general_purpose::STANDARD - .decode(intermediate_rep_query.query) - .map_err(FocusError::DecodeError)?; - let ast: ast::Ast = serde_json::from_slice(&query_decoded)?; - - Ok(run_intermediate_rep_query(task, ast).await?) - } else { - warn!( - "Can't run queries with endpoint type {}", - CONFIG.endpoint_type - ); - Ok(beam::beam_result::perm_failed( - CONFIG.beam_app_id_long.clone(), - vec![task.from.clone()], - task.id, - format!( - "Can't run queries with endpoint type {}", - CONFIG.endpoint_type - ), - )) + ); + Ok(beam::beam_result::perm_failed( + CONFIG.beam_app_id_long.clone(), + vec![task.from.clone()], + task.id, + format!( + "Wrong type of query for an SQL only store: {}, {:?}", + CONFIG.endpoint_type, data + ), + )) + } + }, + EndpointType::Omop => { + let decoded = util::base64_decode(&task.body)?; + let intermediate_rep_query: intermediate_rep::IntermediateRepQuery = + serde_json::from_slice(&decoded)?; + //TODO check that the language is ast + let query_decoded = general_purpose::STANDARD + .decode(intermediate_rep_query.query) + .map_err(FocusError::DecodeError)?; + let ast: ast::Ast = serde_json::from_slice(&query_decoded)?; + + Ok(run_intermediate_rep_query(task, ast).await?) 
+ } } } From 8678af0c33badef5211a627885bd19bdcbaa5e6b Mon Sep 17 00:00:00 2001 From: Enola Knezevic Date: Fri, 19 Jul 2024 16:29:47 +0200 Subject: [PATCH 15/18] process task first param task --- src/main.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main.rs b/src/main.rs index f27e9f0..2833788 100644 --- a/src/main.rs +++ b/src/main.rs @@ -159,17 +159,17 @@ async fn main_loop() -> ExitCode { task_processing::process_tasks(move |task| { let obf_cache = obf_cache.clone(); let report_cache = report_cache.clone(); - process_task(db_pool.clone(), task, obf_cache, report_cache).boxed_local() + process_task(task, obf_cache, report_cache, db_pool.clone()).boxed_local() }) .await; ExitCode::FAILURE } async fn process_task( - db_pool: Option, task: &BeamTask, obf_cache: Arc>, report_cache: Arc>, + db_pool: Option, ) -> Result { debug!("Processing task {}", task.id); From 7899eb105192ac07b352cd4db51cb757bc06e445 Mon Sep 17 00:00:00 2001 From: Enola Knezevic Date: Mon, 22 Jul 2024 14:55:48 +0200 Subject: [PATCH 16/18] pg connection tokio_retry exp backoff + jitter --- Cargo.toml | 2 ++ src/db.rs | 49 +++++++++++++++++++------------------------------ 2 files changed, 21 insertions(+), 30 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index dd59e6a..d801f60 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ rand = { default-features = false, version = "0.8.5" } futures-util = { version = "0.3", default-features = false, features = ["std"] } sqlx = { version = "0.7.4", features = [ "runtime-tokio", "postgres", "macros", "chrono"] } sqlx-pgrow-serde = "0.2.0" +tokio-retry = "0.3" # Logging tracing = { version = "0.1.37", default_features = false } @@ -34,6 +35,7 @@ once_cell = "1.18" # Command Line Interface clap = { version = "4", default_features = false, features = ["std", "env", "derive", "help", "color"] } + [features] default = [] bbmri = [] diff --git a/src/db.rs b/src/db.rs index 2e474ec..82642a3 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,10 +1,13 @@ use crate::errors::FocusError; use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; +use serde_json::Value; use sqlx::{postgres::PgPoolOptions, postgres::PgRow, PgPool}; use sqlx_pgrow_serde::SerMapPgRow; use std::collections::HashMap; use tracing::{warn, info, debug}; +use tokio_retry::strategy::{ExponentialBackoff, jitter}; +use tokio_retry::Retry; + #[derive(Serialize, Deserialize, Debug, Default, Clone)] pub struct SqlQuery { @@ -13,43 +16,29 @@ pub struct SqlQuery { include!(concat!(env!("OUT_DIR"), "/sql_replace_map.rs")); -pub async fn get_pg_connection_pool(pg_url: &str, num_attempts: u32) -> Result { +pub async fn get_pg_connection_pool(pg_url: &str, max_attempts: u32) -> Result { info!("Trying to establish a PostgreSQL connection pool"); - let mut attempts = 0; - let mut err: Option = None; + let retry_strategy = ExponentialBackoff::from_millis(1000) + .map(jitter) + .take(max_attempts as usize); - while attempts < num_attempts { - info!( - "Attempt to connect to PostgreSQL {} of {}", - attempts + 1, - num_attempts - ); - match PgPoolOptions::new() + let result = Retry::spawn(retry_strategy, || async { + info!("Attempting to connect to PostgreSQL"); + PgPoolOptions::new() .max_connections(10) .connect(pg_url) .await - { - Ok(pg_con_pool) => { - info!("PostgreSQL connection successfull"); - return Ok(pg_con_pool); - } - Err(e) => { - warn!( - "Failed to connect to PostgreSQL. 
Attempt {} of {}: {}", - attempts + 1, - num_attempts, - e - ); - err = Some(FocusError::CannotConnectToDatabase(e.to_string())); - } - } - attempts += 1; - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - } - Err(err.unwrap()) + .map_err(|e| { + warn!("Failed to connect to PostgreSQL: {}", e); + FocusError::CannotConnectToDatabase(e.to_string()) + }) + }).await; + + result } + pub async fn healthcheck(pool: &PgPool) -> bool { let res = run_query(pool, SQL_REPLACE_MAP.get("SELECT_TABLES").unwrap()).await; //this file exists, safe to unwrap res.is_ok() From f55cfbb902137d7be94e92218bd55ebff46ec768 Mon Sep 17 00:00:00 2001 From: Enola Knezevic Date: Mon, 22 Jul 2024 15:29:41 +0200 Subject: [PATCH 17/18] requested by clippy --- src/db.rs | 15 ++++++--------- src/main.rs | 12 ++++++------ 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/db.rs b/src/db.rs index 82642a3..6af683c 100644 --- a/src/db.rs +++ b/src/db.rs @@ -4,10 +4,9 @@ use serde_json::Value; use sqlx::{postgres::PgPoolOptions, postgres::PgRow, PgPool}; use sqlx_pgrow_serde::SerMapPgRow; use std::collections::HashMap; -use tracing::{warn, info, debug}; -use tokio_retry::strategy::{ExponentialBackoff, jitter}; +use tokio_retry::strategy::{jitter, ExponentialBackoff}; use tokio_retry::Retry; - +use tracing::{debug, info, warn}; #[derive(Serialize, Deserialize, Debug, Default, Clone)] pub struct SqlQuery { @@ -20,10 +19,10 @@ pub async fn get_pg_connection_pool(pg_url: &str, max_attempts: u32) -> Result
<PgPool, FocusError> {
Result
<PgPool, FocusError> { pub async fn healthcheck(pool: &PgPool) ->
bool { let res = run_query(pool, SQL_REPLACE_MAP.get("SELECT_TABLES").unwrap()).await; //this file exists, safe to unwrap res.is_ok() diff --git a/src/main.rs b/src/main.rs index 2833788..75528bb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -235,9 +235,9 @@ async fn process_task( BASE64.encode(serde_json::to_string(&rows_json)?), )) } else { - return Err(FocusError::CannotConnectToDatabase( + Err(FocusError::CannotConnectToDatabase( "SQL task but no connection String in config".into(), - )); + )) } } else { let query: CqlQuery = match serde_json::from_slice::(&data)? { @@ -276,14 +276,14 @@ async fn process_task( BASE64.encode(serde_json::to_string(&rows_json)?), )) } else { - return Err(FocusError::QueryResultBad( + Err(FocusError::QueryResultBad( "Query executed but result not readable".into(), - )); + )) } } else { - return Err(FocusError::CannotConnectToDatabase( + Err(FocusError::CannotConnectToDatabase( "SQL task but no connection String in config".into(), - )); + )) } } else { warn!( From dff9045c35335c7c0a1d2b8fdde0603f10bb214f Mon Sep 17 00:00:00 2001 From: janskiba Date: Mon, 22 Jul 2024 13:30:21 +0000 Subject: [PATCH 18/18] chore: use `tryhard` as retry crate --- Cargo.toml | 2 +- src/db.rs | 17 ++++++----------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d801f60..1d109df 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ rand = { default-features = false, version = "0.8.5" } futures-util = { version = "0.3", default-features = false, features = ["std"] } sqlx = { version = "0.7.4", features = [ "runtime-tokio", "postgres", "macros", "chrono"] } sqlx-pgrow-serde = "0.2.0" -tokio-retry = "0.3" +tryhard = "0.5" # Logging tracing = { version = "0.1.37", default_features = false } diff --git a/src/db.rs b/src/db.rs index 82642a3..80c2416 100644 --- a/src/db.rs +++ b/src/db.rs @@ -3,10 +3,8 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use sqlx::{postgres::PgPoolOptions, postgres::PgRow, PgPool}; use sqlx_pgrow_serde::SerMapPgRow; -use std::collections::HashMap; +use std::{collections::HashMap, time::Duration}; use tracing::{warn, info, debug}; -use tokio_retry::strategy::{ExponentialBackoff, jitter}; -use tokio_retry::Retry; #[derive(Serialize, Deserialize, Debug, Default, Clone)] @@ -19,11 +17,7 @@ include!(concat!(env!("OUT_DIR"), "/sql_replace_map.rs")); pub async fn get_pg_connection_pool(pg_url: &str, max_attempts: u32) -> Result { info!("Trying to establish a PostgreSQL connection pool"); - let retry_strategy = ExponentialBackoff::from_millis(1000) - .map(jitter) - .take(max_attempts as usize); - - let result = Retry::spawn(retry_strategy, || async { + tryhard::retry_fn(|| async { info!("Attempting to connect to PostgreSQL"); PgPoolOptions::new() .max_connections(10) @@ -33,9 +27,10 @@ pub async fn get_pg_connection_pool(pg_url: &str, max_attempts: u32) -> Result
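Patch 18 replaces the hand-rolled retry loop and the `tokio-retry` strategy with `tryhard`. For orientation, a minimal sketch of retrying pool creation with tryhard 0.5's `retry_fn` builder; the builder methods are tryhard's documented API, but the backoff values and the helper function itself are illustrative assumptions, not lines taken from the patch:

```rust
use std::time::Duration;
use sqlx::{postgres::PgPoolOptions, PgPool};

// Hedged sketch: exponential backoff around PgPool creation via tryhard.
async fn connect_with_retries(pg_url: &str, max_attempts: u32) -> Result<PgPool, sqlx::Error> {
    tryhard::retry_fn(|| async {
        PgPoolOptions::new()
            .max_connections(10)
            .connect(pg_url)
            .await
    })
    .retries(max_attempts)
    .exponential_backoff(Duration::from_millis(1000))
    .max_delay(Duration::from_secs(30))
    .await
}
```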