diff --git a/Cargo.lock b/Cargo.lock index 9e5d48f..9d2ad04 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -654,7 +654,7 @@ checksum = "7e962a19be5cfc3f3bf6dd8f61eb50107f356ad6270fbb3ed41476571db78be5" [[package]] name = "db-pool" -version = "0.1.4" +version = "0.2.0" dependencies = [ "async-graphql", "async-graphql-poem", diff --git a/Cargo.toml b/Cargo.toml index 2f7f51c..2c3bd3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "db-pool" -version = "0.1.4" +version = "0.2.0" edition = "2021" description = "A thread-safe database pool for running database-tied integration tests in parallel" license = "MIT" diff --git a/book/Cargo.lock b/book/Cargo.lock index ac0f8cc..52718a1 100644 --- a/book/Cargo.lock +++ b/book/Cargo.lock @@ -184,7 +184,7 @@ dependencies = [ [[package]] name = "book" -version = "0.1.0" +version = "0.2.0" dependencies = [ "bb8", "bb8-postgres", @@ -446,7 +446,7 @@ dependencies = [ [[package]] name = "db-pool" -version = "0.1.4" +version = "0.2.0" dependencies = [ "async-trait", "bb8", diff --git a/book/Cargo.toml b/book/Cargo.toml index bf357ac..6328ca3 100644 --- a/book/Cargo.toml +++ b/book/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "book" -version = "0.1.0" +version = "0.2.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/book/src/tutorials/async/05.rs b/book/src/tutorials/async/05.rs index 1139ddf..b334bc6 100644 --- a/book/src/tutorials/async/05.rs +++ b/book/src/tutorials/async/05.rs @@ -7,14 +7,12 @@ mod tests { use bb8::Pool; use db_pool::{ r#async::{ - // import connection pool - ConnectionPool, DatabasePool, DatabasePoolBuilderTrait, DieselAsyncPostgresBackend, DieselBb8, - // import reusable object wrapper - Reusable, + // import reusable connection pool + ReusableConnectionPool, }, PrivilegedPostgresConfig, }; @@ -25,7 +23,7 @@ mod tests { // change return type async fn get_connection_pool( - ) -> Reusable<'static, ConnectionPool>> { + ) -> ReusableConnectionPool<'static, DieselAsyncPostgresBackend> { static POOL: OnceCell>> = OnceCell::const_new(); @@ -60,6 +58,6 @@ mod tests { .await; // pull connection pool - db_pool.pull().await + db_pool.pull_immutable().await } } diff --git a/book/src/tutorials/async/06.rs b/book/src/tutorials/async/06.rs index 168a6cf..9e7d914 100644 --- a/book/src/tutorials/async/06.rs +++ b/book/src/tutorials/async/06.rs @@ -6,8 +6,8 @@ mod tests { use db_pool::{ r#async::{ - ConnectionPool, DatabasePool, DatabasePoolBuilderTrait, DieselAsyncPostgresBackend, - DieselBb8, Reusable, + DatabasePool, DatabasePoolBuilderTrait, DieselAsyncPostgresBackend, DieselBb8, + ReusableConnectionPool, }, PrivilegedPostgresConfig, }; @@ -19,7 +19,7 @@ mod tests { use tokio::sync::OnceCell; async fn get_connection_pool( - ) -> Reusable<'static, ConnectionPool>> { + ) -> ReusableConnectionPool<'static, DieselAsyncPostgresBackend> { static POOL: OnceCell>> = OnceCell::const_new(); @@ -53,7 +53,7 @@ mod tests { }) .await; - db_pool.pull().await + db_pool.pull_immutable().await } // add test case diff --git a/book/src/tutorials/async/07.rs b/book/src/tutorials/async/07.rs index 0d74a13..30d07b2 100644 --- a/book/src/tutorials/async/07.rs +++ b/book/src/tutorials/async/07.rs @@ -8,8 +8,8 @@ mod tests { use bb8::Pool; use db_pool::{ r#async::{ - ConnectionPool, DatabasePool, DatabasePoolBuilderTrait, DieselAsyncPostgresBackend, - DieselBb8, Reusable, + DatabasePool, DatabasePoolBuilderTrait, DieselAsyncPostgresBackend, DieselBb8, + ReusableConnectionPool, }, PrivilegedPostgresConfig, }; @@ -21,7 +21,7 @@ mod tests { use tokio_shared_rt::test; async fn get_connection_pool( - ) -> Reusable<'static, ConnectionPool>> { + ) -> ReusableConnectionPool<'static, DieselAsyncPostgresBackend> { static POOL: OnceCell>> = OnceCell::const_new(); @@ -55,7 +55,7 @@ mod tests { }) .await; - db_pool.pull().await + db_pool.pull_immutable().await } async fn test() { diff --git a/book/src/tutorials/sync/05.rs b/book/src/tutorials/sync/05.rs index 1875e95..719cd2c 100644 --- a/book/src/tutorials/sync/05.rs +++ b/book/src/tutorials/sync/05.rs @@ -8,13 +8,11 @@ mod tests { use db_pool::{ sync::{ - // import connection pool - ConnectionPool, DatabasePool, DatabasePoolBuilderTrait, DieselPostgresBackend, - // import reusable object wrapper - Reusable, + // import reusable connection pool + ReusableConnectionPool, }, PrivilegedPostgresConfig, }; @@ -23,7 +21,7 @@ mod tests { use r2d2::Pool; // change return type - fn get_connection_pool() -> Reusable<'static, ConnectionPool> { + fn get_connection_pool() -> ReusableConnectionPool<'static, DieselPostgresBackend> { static POOL: OnceLock> = OnceLock::new(); let db_pool = POOL.get_or_init(|| { @@ -47,6 +45,6 @@ mod tests { }); // pull connection pool - db_pool.pull() + db_pool.pull_immutable() } } diff --git a/book/src/tutorials/sync/06.rs b/book/src/tutorials/sync/06.rs index a0aaacc..60bbd18 100644 --- a/book/src/tutorials/sync/06.rs +++ b/book/src/tutorials/sync/06.rs @@ -8,7 +8,7 @@ mod tests { use db_pool::{ sync::{ - ConnectionPool, DatabasePool, DatabasePoolBuilderTrait, DieselPostgresBackend, Reusable, + DatabasePool, DatabasePoolBuilderTrait, DieselPostgresBackend, ReusableConnectionPool, }, PrivilegedPostgresConfig, }; @@ -17,7 +17,7 @@ mod tests { use dotenvy::dotenv; use r2d2::Pool; - fn get_connection_pool() -> Reusable<'static, ConnectionPool> { + fn get_connection_pool() -> ReusableConnectionPool<'static, DieselPostgresBackend> { static POOL: OnceLock> = OnceLock::new(); let db_pool = POOL.get_or_init(|| { @@ -40,7 +40,7 @@ mod tests { backend.create_database_pool().unwrap() }); - db_pool.pull() + db_pool.pull_immutable() } // add test case diff --git a/book/src/tutorials/sync/07.rs b/book/src/tutorials/sync/07.rs index ac331f5..6c464e1 100644 --- a/book/src/tutorials/sync/07.rs +++ b/book/src/tutorials/sync/07.rs @@ -6,7 +6,7 @@ mod tests { use db_pool::{ sync::{ - ConnectionPool, DatabasePool, DatabasePoolBuilderTrait, DieselPostgresBackend, Reusable, + DatabasePool, DatabasePoolBuilderTrait, DieselPostgresBackend, ReusableConnectionPool, }, PrivilegedPostgresConfig, }; @@ -14,7 +14,7 @@ mod tests { use dotenvy::dotenv; use r2d2::Pool; - fn get_connection_pool() -> Reusable<'static, ConnectionPool> { + fn get_connection_pool() -> ReusableConnectionPool<'static, DieselPostgresBackend> { static POOL: OnceLock> = OnceLock::new(); let db_pool = POOL.get_or_init(|| { @@ -37,7 +37,7 @@ mod tests { backend.create_database_pool().unwrap() }); - db_pool.pull() + db_pool.pull_immutable() } fn test() { diff --git a/examples/async-graphql/main.rs b/examples/async-graphql/main.rs index a81afcd..595d22a 100644 --- a/examples/async-graphql/main.rs +++ b/examples/async-graphql/main.rs @@ -189,7 +189,7 @@ mod tests { }) .await; - let conn_pool = db_pool.pull().await; + let conn_pool = db_pool.pull_immutable().await; PoolWrapper::ReusablePool(conn_pool) } diff --git a/examples/diesel_async_mysql.rs b/examples/diesel_async_mysql.rs index 46747c7..9145185 100644 --- a/examples/diesel_async_mysql.rs +++ b/examples/diesel_async_mysql.rs @@ -7,8 +7,8 @@ mod tests { use bb8::Pool; use db_pool::{ r#async::{ - ConnectionPool, DatabasePool, DatabasePoolBuilderTrait, DieselAsyncMySQLBackend, - DieselBb8, Reusable, + DatabasePool, DatabasePoolBuilderTrait, DieselAsyncMySQLBackend, DieselBb8, + ReusableConnectionPool, }, PrivilegedMySQLConfig, }; @@ -19,7 +19,7 @@ mod tests { use tokio_shared_rt::test; async fn get_connection_pool( - ) -> Reusable<'static, ConnectionPool>> { + ) -> ReusableConnectionPool<'static, DieselAsyncMySQLBackend> { static POOL: OnceCell>> = OnceCell::const_new(); @@ -51,7 +51,7 @@ mod tests { }) .await; - db_pool.pull().await + db_pool.pull_immutable().await } async fn test() { diff --git a/examples/diesel_async_postgres.rs b/examples/diesel_async_postgres.rs index 18eb7fb..9d2b718 100644 --- a/examples/diesel_async_postgres.rs +++ b/examples/diesel_async_postgres.rs @@ -7,8 +7,8 @@ mod tests { use bb8::Pool; use db_pool::{ r#async::{ - ConnectionPool, DatabasePool, DatabasePoolBuilderTrait, DieselAsyncPostgresBackend, - DieselBb8, Reusable, + DatabasePool, DatabasePoolBuilderTrait, DieselAsyncPostgresBackend, DieselBb8, + ReusableConnectionPool, }, PrivilegedPostgresConfig, }; @@ -19,7 +19,7 @@ mod tests { use tokio_shared_rt::test; async fn get_connection_pool( - ) -> Reusable<'static, ConnectionPool>> { + ) -> ReusableConnectionPool<'static, DieselAsyncPostgresBackend> { static POOL: OnceCell>> = OnceCell::const_new(); @@ -53,7 +53,7 @@ mod tests { }) .await; - db_pool.pull().await + db_pool.pull_immutable().await } async fn test() { diff --git a/examples/diesel_mysql.rs b/examples/diesel_mysql.rs index de6500d..8c125cc 100644 --- a/examples/diesel_mysql.rs +++ b/examples/diesel_mysql.rs @@ -6,7 +6,7 @@ mod tests { use db_pool::{ sync::{ - ConnectionPool, DatabasePool, DatabasePoolBuilderTrait, DieselMySQLBackend, Reusable, + DatabasePool, DatabasePoolBuilderTrait, DieselMySQLBackend, ReusableConnectionPool, }, PrivilegedMySQLConfig, }; @@ -14,7 +14,7 @@ mod tests { use dotenvy::dotenv; use r2d2::Pool; - fn get_connection_pool() -> Reusable<'static, ConnectionPool> { + fn get_connection_pool() -> ReusableConnectionPool<'static, DieselMySQLBackend> { static POOL: OnceLock> = OnceLock::new(); let db_pool = POOL.get_or_init(|| { @@ -37,7 +37,7 @@ mod tests { backend.create_database_pool().unwrap() }); - db_pool.pull() + db_pool.pull_immutable() } fn test() { diff --git a/examples/diesel_postgres.rs b/examples/diesel_postgres.rs index af5817c..f802d89 100644 --- a/examples/diesel_postgres.rs +++ b/examples/diesel_postgres.rs @@ -6,7 +6,7 @@ mod tests { use db_pool::{ sync::{ - ConnectionPool, DatabasePool, DatabasePoolBuilderTrait, DieselPostgresBackend, Reusable, + DatabasePool, DatabasePoolBuilderTrait, DieselPostgresBackend, ReusableConnectionPool, }, PrivilegedPostgresConfig, }; @@ -14,7 +14,7 @@ mod tests { use dotenvy::dotenv; use r2d2::Pool; - fn get_connection_pool() -> Reusable<'static, ConnectionPool> { + fn get_connection_pool() -> ReusableConnectionPool<'static, DieselPostgresBackend> { static POOL: OnceLock> = OnceLock::new(); let db_pool = POOL.get_or_init(|| { @@ -37,7 +37,7 @@ mod tests { backend.create_database_pool().unwrap() }); - db_pool.pull() + db_pool.pull_immutable() } fn test() { diff --git a/examples/mysql.rs b/examples/mysql.rs index a9fed5a..5eed700 100644 --- a/examples/mysql.rs +++ b/examples/mysql.rs @@ -5,14 +5,14 @@ mod tests { use std::sync::OnceLock; use db_pool::{ - sync::{ConnectionPool, DatabasePool, DatabasePoolBuilderTrait, MySQLBackend, Reusable}, + sync::{DatabasePool, DatabasePoolBuilderTrait, MySQLBackend, ReusableConnectionPool}, PrivilegedMySQLConfig, }; use dotenvy::dotenv; use mysql::{params, prelude::Queryable}; use r2d2::Pool; - fn get_connection_pool() -> Reusable<'static, ConnectionPool> { + fn get_connection_pool() -> ReusableConnectionPool<'static, MySQLBackend> { static POOL: OnceLock> = OnceLock::new(); let db_pool = POOL.get_or_init(|| { @@ -36,7 +36,7 @@ mod tests { backend.create_database_pool().unwrap() }); - db_pool.pull() + db_pool.pull_immutable() } fn test() { diff --git a/examples/postgres.rs b/examples/postgres.rs index 78e9c00..8026c42 100644 --- a/examples/postgres.rs +++ b/examples/postgres.rs @@ -5,13 +5,13 @@ mod tests { use std::sync::OnceLock; use db_pool::{ - sync::{ConnectionPool, DatabasePool, DatabasePoolBuilderTrait, PostgresBackend, Reusable}, + sync::{DatabasePool, DatabasePoolBuilderTrait, PostgresBackend, ReusableConnectionPool}, PrivilegedPostgresConfig, }; use dotenvy::dotenv; use r2d2::Pool; - fn get_connection_pool() -> Reusable<'static, ConnectionPool> { + fn get_connection_pool() -> ReusableConnectionPool<'static, PostgresBackend> { static POOL: OnceLock> = OnceLock::new(); let db_pool = POOL.get_or_init(|| { @@ -36,7 +36,7 @@ mod tests { backend.create_database_pool().unwrap() }); - db_pool.pull() + db_pool.pull_immutable() } fn test() { diff --git a/examples/sea_orm_mysql.rs b/examples/sea_orm_mysql.rs index 1a67fa6..cc146f3 100644 --- a/examples/sea_orm_mysql.rs +++ b/examples/sea_orm_mysql.rs @@ -6,7 +6,7 @@ mod tests { use db_pool::{ r#async::{ - ConnectionPool, DatabasePool, DatabasePoolBuilderTrait, Reusable, SeaORMMySQLBackend, + DatabasePool, DatabasePoolBuilderTrait, ReusableConnectionPool, SeaORMMySQLBackend, }, PrivilegedMySQLConfig, }; @@ -15,7 +15,7 @@ mod tests { use tokio::sync::OnceCell; use tokio_shared_rt::test; - async fn get_connection_pool() -> Reusable<'static, ConnectionPool> { + async fn get_connection_pool() -> ReusableConnectionPool<'static, SeaORMMySQLBackend> { static POOL: OnceCell> = OnceCell::const_new(); let db_pool = POOL @@ -49,7 +49,7 @@ mod tests { }) .await; - db_pool.pull().await + db_pool.pull_immutable().await } async fn test() { diff --git a/examples/sea_orm_postgres.rs b/examples/sea_orm_postgres.rs index 9db8bd3..a5f1201 100644 --- a/examples/sea_orm_postgres.rs +++ b/examples/sea_orm_postgres.rs @@ -6,7 +6,7 @@ mod tests { use db_pool::{ r#async::{ - ConnectionPool, DatabasePool, DatabasePoolBuilderTrait, Reusable, SeaORMPostgresBackend, + DatabasePool, DatabasePoolBuilderTrait, ReusableConnectionPool, SeaORMPostgresBackend, }, PrivilegedPostgresConfig, }; @@ -15,7 +15,7 @@ mod tests { use tokio::sync::OnceCell; use tokio_shared_rt::test; - async fn get_connection_pool() -> Reusable<'static, ConnectionPool> { + async fn get_connection_pool() -> ReusableConnectionPool<'static, SeaORMPostgresBackend> { static POOL: OnceCell> = OnceCell::const_new(); let db_pool = POOL @@ -49,7 +49,7 @@ mod tests { }) .await; - db_pool.pull().await + db_pool.pull_immutable().await } async fn test() { diff --git a/examples/sqlx_mysql.rs b/examples/sqlx_mysql.rs index ef8e0d5..5f0fd98 100644 --- a/examples/sqlx_mysql.rs +++ b/examples/sqlx_mysql.rs @@ -6,7 +6,7 @@ mod tests { use db_pool::{ r#async::{ - ConnectionPool, DatabasePool, DatabasePoolBuilderTrait, Reusable, SqlxMySQLBackend, + DatabasePool, DatabasePoolBuilderTrait, ReusableConnectionPool, SqlxMySQLBackend, }, PrivilegedMySQLConfig, }; @@ -15,7 +15,7 @@ mod tests { use tokio::sync::OnceCell; use tokio_shared_rt::test; - async fn get_connection_pool() -> Reusable<'static, ConnectionPool> { + async fn get_connection_pool() -> ReusableConnectionPool<'static, SqlxMySQLBackend> { static POOL: OnceCell> = OnceCell::const_new(); let db_pool = POOL @@ -43,7 +43,7 @@ mod tests { }) .await; - db_pool.pull().await + db_pool.pull_immutable().await } async fn test() { diff --git a/examples/sqlx_postgres.rs b/examples/sqlx_postgres.rs index 0e939c9..6f228bc 100644 --- a/examples/sqlx_postgres.rs +++ b/examples/sqlx_postgres.rs @@ -6,7 +6,7 @@ mod tests { use db_pool::{ r#async::{ - ConnectionPool, DatabasePool, DatabasePoolBuilderTrait, Reusable, SqlxPostgresBackend, + DatabasePool, DatabasePoolBuilderTrait, ReusableConnectionPool, SqlxPostgresBackend, }, PrivilegedPostgresConfig, }; @@ -15,7 +15,7 @@ mod tests { use tokio::sync::OnceCell; use tokio_shared_rt::test; - async fn get_connection_pool() -> Reusable<'static, ConnectionPool> { + async fn get_connection_pool() -> ReusableConnectionPool<'static, SqlxPostgresBackend> { static POOL: OnceCell> = OnceCell::const_new(); let db_pool = POOL @@ -45,7 +45,7 @@ mod tests { }) .await; - db_pool.pull().await + db_pool.pull_immutable().await } async fn test() { diff --git a/examples/tokio_postgres.rs b/examples/tokio_postgres.rs index c0f1df2..b22dd4d 100644 --- a/examples/tokio_postgres.rs +++ b/examples/tokio_postgres.rs @@ -7,7 +7,7 @@ mod tests { use bb8::Pool; use db_pool::{ r#async::{ - ConnectionPool, DatabasePool, DatabasePoolBuilderTrait, Reusable, TokioPostgresBackend, + DatabasePool, DatabasePoolBuilderTrait, ReusableConnectionPool, TokioPostgresBackend, TokioPostgresBb8, }, PrivilegedPostgresConfig, @@ -17,7 +17,7 @@ mod tests { use tokio_shared_rt::test; async fn get_connection_pool( - ) -> Reusable<'static, ConnectionPool>> { + ) -> ReusableConnectionPool<'static, TokioPostgresBackend> { static POOL: OnceCell>> = OnceCell::const_new(); @@ -51,7 +51,7 @@ mod tests { }) .await; - db_pool.pull().await + db_pool.pull_immutable().await } async fn test() { diff --git a/src/async/backend/common/pool/diesel/mod.rs b/src/async/backend/common/pool/diesel/mod.rs index 27c0d29..f30db96 100644 --- a/src/async/backend/common/pool/diesel/mod.rs +++ b/src/async/backend/common/pool/diesel/mod.rs @@ -1,7 +1,7 @@ #[cfg(any(all(test, feature = "_diesel-async"), feature = "diesel-async-bb8"))] pub mod bb8; -#[cfg(feature = "diesel-async-deadpool")] -pub mod deadpool; -// #[cfg(feature = "diesel-async-mobc")] -// pub mod mobc; +// #[cfg(feature = "diesel-async-deadpool")] +// pub mod deadpool; +#[cfg(feature = "diesel-async-mobc")] +pub mod mobc; pub(in crate::r#async::backend) mod r#trait; diff --git a/src/async/backend/common/pool/tokio_postgres/mobc.rs b/src/async/backend/common/pool/tokio_postgres/mobc.rs index eddda58..d298079 100644 --- a/src/async/backend/common/pool/tokio_postgres/mobc.rs +++ b/src/async/backend/common/pool/tokio_postgres/mobc.rs @@ -14,7 +14,7 @@ use super::r#trait::TokioPostgresPoolAssociation; type Manager = PgConnectionManager; -/// [`tokio-postgres mobc`](https://github.com/importcjj/mobc-postgres) association +/// [`tokio-postgres mobc`](https://docs.rs/mobc-postgres/latest/mobc_postgres/) association /// # Example /// ``` /// use db_pool::r#async::{TokioPostgresBackend, TokioPostgresMobc}; diff --git a/src/async/backend/common/pool/tokio_postgres/mod.rs b/src/async/backend/common/pool/tokio_postgres/mod.rs index 2e79f06..5fde0d9 100644 --- a/src/async/backend/common/pool/tokio_postgres/mod.rs +++ b/src/async/backend/common/pool/tokio_postgres/mod.rs @@ -1,7 +1,7 @@ #[cfg(any(all(test, feature = "tokio-postgres"), feature = "tokio-postgres-bb8"))] pub mod bb8; -#[cfg(feature = "tokio-postgres-deadpool")] -pub mod deadpool; +// #[cfg(feature = "tokio-postgres-deadpool")] +// pub mod deadpool; #[cfg(feature = "tokio-postgres-mobc")] pub mod mobc; pub(in crate::r#async::backend) mod r#trait; diff --git a/src/async/backend/mod.rs b/src/async/backend/mod.rs index 7a5fb6a..0bcca3d 100644 --- a/src/async/backend/mod.rs +++ b/src/async/backend/mod.rs @@ -10,14 +10,14 @@ pub(crate) use error::Error; #[cfg(feature = "diesel-async-bb8")] pub use common::pool::diesel::bb8::DieselBb8; -#[cfg(feature = "diesel-async-deadpool")] -pub use common::pool::diesel::deadpool::DieselDeadpool; -// #[cfg(feature = "diesel-async-mobc")] -// pub use common::pool::diesel::mobc::DieselMobc; +// #[cfg(feature = "diesel-async-deadpool")] +// pub use common::pool::diesel::deadpool::DieselDeadpool; +#[cfg(feature = "diesel-async-mobc")] +pub use common::pool::diesel::mobc::DieselMobc; #[cfg(feature = "tokio-postgres-bb8")] pub use common::pool::tokio_postgres::bb8::TokioPostgresBb8; -#[cfg(feature = "tokio-postgres-deadpool")] -pub use common::pool::tokio_postgres::deadpool::TokioPostgresDeadpool; +// #[cfg(feature = "tokio-postgres-deadpool")] +// pub use common::pool::tokio_postgres::deadpool::TokioPostgresDeadpool; #[cfg(feature = "tokio-postgres-mobc")] pub use common::pool::tokio_postgres::mobc::TokioPostgresMobc; #[cfg(feature = "diesel-async-mysql")] diff --git a/src/async/backend/mysql/diesel.rs b/src/async/backend/mysql/diesel.rs index 29ff3ce..b6e9188 100644 --- a/src/async/backend/mysql/diesel.rs +++ b/src/async/backend/mysql/diesel.rs @@ -228,15 +228,22 @@ impl> Backend for DieselAsyncMySQ async fn create( &self, db_id: uuid::Uuid, + restrict_privileges: bool, ) -> Result> { - MySQLBackendWrapper::new(self).create(db_id).await + MySQLBackendWrapper::new(self) + .create(db_id, restrict_privileges) + .await } async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> { MySQLBackendWrapper::new(self).clean(db_id).await } - async fn drop(&self, db_id: uuid::Uuid) -> Result<(), BError> { + async fn drop( + &self, + db_id: uuid::Uuid, + _is_restricted: bool, + ) -> Result<(), BError> { MySQLBackendWrapper::new(self).drop(db_id).await } } @@ -257,7 +264,16 @@ mod tests { common::statement::mysql::tests::{ CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS, }, - r#async::{backend::common::pool::diesel::bb8::DieselBb8, db_pool::DatabasePoolBuilder}, + r#async::{ + backend::{ + common::pool::diesel::bb8::DieselBb8, + mysql::r#trait::tests::{ + test_backend_creates_database_with_unrestricted_privileges, + test_pool_drops_created_unrestricted_database, + }, + }, + db_pool::DatabasePoolBuilder, + }, tests::get_privileged_mysql_config, }; @@ -265,7 +281,7 @@ mod tests { super::r#trait::tests::{ test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables, test_backend_creates_database_with_restricted_privileges, test_backend_drops_database, - test_backend_drops_previous_databases, test_pool_drops_created_databases, + test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases, test_pool_drops_previous_databases, MySQLDropLock, }, DieselAsyncMySQLBackend, @@ -318,6 +334,12 @@ mod tests { test_backend_creates_database_with_restricted_privileges(backend).await; } + #[test(flavor = "multi_thread", shared)] + async fn backend_creates_database_with_unrestricted_privileges() { + let backend = create_backend(true).await.drop_previous_databases(false); + test_backend_creates_database_with_unrestricted_privileges(backend).await; + } + #[test(flavor = "multi_thread", shared)] async fn backend_cleans_database_with_tables() { let backend = create_backend(true).await.drop_previous_databases(false); @@ -331,9 +353,15 @@ mod tests { } #[test(flavor = "multi_thread", shared)] - async fn backend_drops_database() { + async fn backend_drops_restricted_database() { + let backend = create_backend(true).await.drop_previous_databases(false); + test_backend_drops_database(backend, true).await; + } + + #[test(flavor = "multi_thread", shared)] + async fn backend_drops_unrestricted_database() { let backend = create_backend(true).await.drop_previous_databases(false); - test_backend_drops_database(backend).await; + test_backend_drops_database(backend, false).await; } #[test(flavor = "multi_thread", shared)] @@ -354,7 +382,7 @@ mod tests { async { let db_pool = backend.create_database_pool().await.unwrap(); - let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // insert single row into each database join_all( @@ -403,7 +431,7 @@ mod tests { async { let db_pool = backend.create_database_pool().await.unwrap(); - let conn_pool = db_pool.pull().await; + let conn_pool = db_pool.pull_immutable().await; let conn = &mut conn_pool.get().await.unwrap(); // DDL statements must fail @@ -420,6 +448,33 @@ mod tests { .await; } + #[test(flavor = "multi_thread", shared)] + async fn pool_provides_unrestricted_databases() { + let backend = create_backend(true).await.drop_previous_databases(false); + + async { + let db_pool = backend.create_database_pool().await.unwrap(); + + // DML statements must succeed + { + let conn_pool = db_pool.create_mutable().await.unwrap(); + let conn = &mut conn_pool.get().await.unwrap(); + for stmt in DML_STATEMENTS { + assert!(sql_query(stmt).execute(conn).await.is_ok()); + } + } + + // DDL statements must succeed + for stmt in DDL_STATEMENTS { + let conn_pool = db_pool.create_mutable().await.unwrap(); + let conn = &mut conn_pool.get().await.unwrap(); + assert!(sql_query(stmt).execute(conn).await.is_ok()); + } + } + .lock_read() + .await; + } + #[test(flavor = "multi_thread", shared)] async fn pool_provides_clean_databases() { const NUM_DBS: i64 = 3; @@ -431,7 +486,7 @@ mod tests { // fetch connection pools the first time { - let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // databases must be empty join_all(conn_pools.iter().map(|conn_pool| async move { @@ -459,7 +514,7 @@ mod tests { // fetch same connection pools a second time { - let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // databases must be empty join_all(conn_pools.iter().map(|conn_pool| async move { @@ -477,8 +532,14 @@ mod tests { } #[test(flavor = "multi_thread", shared)] - async fn pool_drops_created_databases() { + async fn pool_drops_created_restricted_databases() { + let backend = create_backend(false).await; + test_pool_drops_created_restricted_databases(backend).await; + } + + #[test(flavor = "multi_thread", shared)] + async fn pool_drops_created_unrestricted_database() { let backend = create_backend(false).await; - test_pool_drops_created_databases(backend).await; + test_pool_drops_created_unrestricted_database(backend).await; } } diff --git a/src/async/backend/mysql/sea_orm.rs b/src/async/backend/mysql/sea_orm.rs index e27c6ff..f5aad9b 100644 --- a/src/async/backend/mysql/sea_orm.rs +++ b/src/async/backend/mysql/sea_orm.rs @@ -272,15 +272,21 @@ impl Backend for SeaORMMySQLBackend { MySQLBackendWrapper::new(self).init().await } - async fn create(&self, db_id: uuid::Uuid) -> Result { - MySQLBackendWrapper::new(self).create(db_id).await + async fn create( + &self, + db_id: uuid::Uuid, + restrict_privileges: bool, + ) -> Result { + MySQLBackendWrapper::new(self) + .create(db_id, restrict_privileges) + .await } async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> { MySQLBackendWrapper::new(self).clean(db_id).await } - async fn drop(&self, db_id: uuid::Uuid) -> Result<(), BError> { + async fn drop(&self, db_id: uuid::Uuid, _is_restricted: bool) -> Result<(), BError> { MySQLBackendWrapper::new(self).drop(db_id).await } } @@ -301,7 +307,14 @@ mod tests { common::statement::mysql::tests::{ CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS, }, - r#async::db_pool::DatabasePoolBuilder, + r#async::{ + backend::mysql::r#trait::tests::{ + test_backend_creates_database_with_unrestricted_privileges, + test_pool_drops_created_restricted_databases, + test_pool_drops_created_unrestricted_database, + }, + db_pool::DatabasePoolBuilder, + }, tests::get_privileged_mysql_config, }; @@ -309,8 +322,8 @@ mod tests { super::r#trait::tests::{ test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables, test_backend_creates_database_with_restricted_privileges, test_backend_drops_database, - test_backend_drops_previous_databases, test_pool_drops_created_databases, - test_pool_drops_previous_databases, MySQLDropLock, + test_backend_drops_previous_databases, test_pool_drops_previous_databases, + MySQLDropLock, }, SeaORMMySQLBackend, }; @@ -363,6 +376,12 @@ mod tests { test_backend_creates_database_with_restricted_privileges(backend).await; } + #[test(flavor = "multi_thread", shared)] + async fn backend_creates_database_with_unrestricted_privileges() { + let backend = create_backend(true).await.drop_previous_databases(false); + test_backend_creates_database_with_unrestricted_privileges(backend).await; + } + #[test(flavor = "multi_thread", shared)] async fn backend_cleans_database_with_tables() { let backend = create_backend(true).await.drop_previous_databases(false); @@ -376,9 +395,15 @@ mod tests { } #[test(flavor = "multi_thread", shared)] - async fn backend_drops_database() { + async fn backend_drops_restricted_database() { let backend = create_backend(true).await.drop_previous_databases(false); - test_backend_drops_database(backend).await; + test_backend_drops_database(backend, true).await; + } + + #[test(flavor = "multi_thread", shared)] + async fn backend_drops_unrestricted_database() { + let backend = create_backend(true).await.drop_previous_databases(false); + test_backend_drops_database(backend, false).await; } #[test(flavor = "multi_thread", shared)] @@ -404,7 +429,7 @@ mod tests { async { let db_pool = backend.create_database_pool().await.unwrap(); - let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // insert single row into each database join_all(conns.iter().enumerate().map(|(i, conn)| async move { @@ -443,7 +468,7 @@ mod tests { async { let db_pool = backend.create_database_pool().await.unwrap(); - let conn = db_pool.pull().await; + let conn = db_pool.pull_immutable().await; // DDL statements must fail for stmt in DDL_STATEMENTS { @@ -459,6 +484,31 @@ mod tests { .await; } + #[test(flavor = "multi_thread", shared)] + async fn pool_provides_unrestricted_databases() { + let backend = create_backend(true).await.drop_previous_databases(false); + + async { + let db_pool = backend.create_database_pool().await.unwrap(); + + // DML statements must succeed + { + let conn = db_pool.create_mutable().await.unwrap(); + for stmt in DML_STATEMENTS { + assert!(conn.execute_unprepared(stmt).await.is_ok()); + } + } + + // DDL statements must succeed + for stmt in DDL_STATEMENTS { + let conn = db_pool.create_mutable().await.unwrap(); + assert!(conn.execute_unprepared(stmt).await.is_ok()); + } + } + .lock_read() + .await; + } + #[test(flavor = "multi_thread", shared)] async fn pool_provides_clean_databases() { const NUM_DBS: i64 = 3; @@ -470,7 +520,7 @@ mod tests { // fetch connection pools the first time { - let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // databases must be empty join_all(conns.iter().map(|conn| async move { @@ -491,7 +541,7 @@ mod tests { // fetch same connection pools a second time { - let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // databases must be empty join_all(conns.iter().map(|conn| async move { @@ -505,8 +555,14 @@ mod tests { } #[test(flavor = "multi_thread", shared)] - async fn pool_drops_created_databases() { + async fn pool_drops_created_restricted_databases() { + let backend = create_backend(false).await; + test_pool_drops_created_restricted_databases(backend).await; + } + + #[test(flavor = "multi_thread", shared)] + async fn pool_drops_created_unrestricted_database() { let backend = create_backend(false).await; - test_pool_drops_created_databases(backend).await; + test_pool_drops_created_unrestricted_database(backend).await; } } diff --git a/src/async/backend/mysql/sqlx.rs b/src/async/backend/mysql/sqlx.rs index 63cb649..6204955 100644 --- a/src/async/backend/mysql/sqlx.rs +++ b/src/async/backend/mysql/sqlx.rs @@ -201,15 +201,21 @@ impl Backend for SqlxMySQLBackend { MySQLBackendWrapper::new(self).init().await } - async fn create(&self, db_id: uuid::Uuid) -> Result { - MySQLBackendWrapper::new(self).create(db_id).await + async fn create( + &self, + db_id: uuid::Uuid, + restrict_privileges: bool, + ) -> Result { + MySQLBackendWrapper::new(self) + .create(db_id, restrict_privileges) + .await } async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> { MySQLBackendWrapper::new(self).clean(db_id).await } - async fn drop(&self, db_id: uuid::Uuid) -> Result<(), BError> { + async fn drop(&self, db_id: uuid::Uuid, _is_restricted: bool) -> Result<(), BError> { MySQLBackendWrapper::new(self).drop(db_id).await } } @@ -229,7 +235,10 @@ mod tests { common::statement::mysql::tests::{ CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS, }, - r#async::db_pool::DatabasePoolBuilder, + r#async::{ + backend::mysql::r#trait::tests::test_backend_creates_database_with_unrestricted_privileges, + db_pool::DatabasePoolBuilder, + }, tests::get_privileged_mysql_config, }; @@ -237,8 +246,9 @@ mod tests { super::r#trait::tests::{ test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables, test_backend_creates_database_with_restricted_privileges, test_backend_drops_database, - test_backend_drops_previous_databases, test_pool_drops_created_databases, - test_pool_drops_previous_databases, MySQLDropLock, + test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases, + test_pool_drops_created_unrestricted_database, test_pool_drops_previous_databases, + MySQLDropLock, }, SqlxMySQLBackend, }; @@ -285,6 +295,12 @@ mod tests { test_backend_creates_database_with_restricted_privileges(backend).await; } + #[test(flavor = "multi_thread", shared)] + async fn backend_creates_database_with_unrestricted_privileges() { + let backend = create_backend(true).drop_previous_databases(false); + test_backend_creates_database_with_unrestricted_privileges(backend).await; + } + #[test(flavor = "multi_thread", shared)] async fn backend_cleans_database_with_tables() { let backend = create_backend(true).drop_previous_databases(false); @@ -298,9 +314,15 @@ mod tests { } #[test(flavor = "multi_thread", shared)] - async fn backend_drops_database() { + async fn backend_drops_restricted_database() { let backend = create_backend(true).drop_previous_databases(false); - test_backend_drops_database(backend).await; + test_backend_drops_database(backend, true).await; + } + + #[test(flavor = "multi_thread", shared)] + async fn backend_drops_unrestricted_database() { + let backend = create_backend(true).drop_previous_databases(false); + test_backend_drops_database(backend, false).await; } #[test(flavor = "multi_thread", shared)] @@ -326,7 +348,7 @@ mod tests { async { let db_pool = backend.create_database_pool().await.unwrap(); - let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // insert single row into each database join_all( @@ -373,7 +395,7 @@ mod tests { async { let db_pool = backend.create_database_pool().await.unwrap(); - let conn_pool = db_pool.pull().await; + let conn_pool = db_pool.pull_immutable().await; let conn = &mut conn_pool.acquire().await.unwrap(); // DDL statements must fail @@ -390,6 +412,33 @@ mod tests { .await; } + #[test(flavor = "multi_thread", shared)] + async fn pool_provides_unrestricted_databases() { + let backend = create_backend(true).drop_previous_databases(false); + + async { + let db_pool = backend.create_database_pool().await.unwrap(); + + // DML statements must succeed + { + let conn_pool = db_pool.create_mutable().await.unwrap(); + let conn = &mut conn_pool.acquire().await.unwrap(); + for stmt in DML_STATEMENTS { + assert!(conn.execute(stmt).await.is_ok()); + } + } + + // DDL statements must succeed + for stmt in DDL_STATEMENTS { + let conn_pool = db_pool.create_mutable().await.unwrap(); + let conn = &mut conn_pool.acquire().await.unwrap(); + assert!(conn.execute(stmt).await.is_ok()); + } + } + .lock_read() + .await; + } + #[test(flavor = "multi_thread", shared)] async fn pool_provides_clean_databases() { const NUM_DBS: i64 = 3; @@ -401,7 +450,7 @@ mod tests { // fetch connection pools the first time { - let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // databases must be empty join_all(conn_pools.iter().map(|conn_pool| async move { @@ -429,7 +478,7 @@ mod tests { // fetch same connection pools a second time { - let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // databases must be empty join_all(conn_pools.iter().map(|conn_pool| async move { @@ -450,8 +499,14 @@ mod tests { } #[test(flavor = "multi_thread", shared)] - async fn pool_drops_created_databases() { + async fn pool_drops_created_restricted_databases() { + let backend = create_backend(false); + test_pool_drops_created_restricted_databases(backend).await; + } + + #[test(flavor = "multi_thread", shared)] + async fn pool_drops_created_unrestricted_databases() { let backend = create_backend(false); - test_pool_drops_created_databases(backend).await; + test_pool_drops_created_unrestricted_database(backend).await; } } diff --git a/src/async/backend/mysql/trait.rs b/src/async/backend/mysql/trait.rs index 400e8c4..4221bf9 100644 --- a/src/async/backend/mysql/trait.rs +++ b/src/async/backend/mysql/trait.rs @@ -155,6 +155,7 @@ where pub(super) async fn create( &'backend self, db_id: uuid::Uuid, + restrict_privileges: bool, ) -> Result> { // Get database name based on UUID @@ -171,7 +172,7 @@ where .await .map_err(Into::into)?; - // Create CRUD user + // Create user self.execute_query(mysql::create_user(db_name, host).as_str(), conn) .await .map_err(Into::into)?; @@ -185,16 +186,27 @@ where .await .map_err(Into::into)?; - // Grant privileges to CRUD role - self.execute_query(mysql::grant_privileges(db_name, host).as_str(), conn) + if restrict_privileges { + // Grant privileges to restricted user + self.execute_query( + mysql::grant_restricted_privileges(db_name, host).as_str(), + conn, + ) .await .map_err(Into::into)?; + } else { + // Grant all privileges to database-unrestricted user + self.execute_query(mysql::grant_all_privileges(db_name, host).as_str(), conn) + .await + .map_err(Into::into)?; + } - // Create connection pool with CRUD role + // Create connection pool with attached user let pool = self .create_connection_pool(db_id) .await .map_err(Into::into)?; + Ok(pool) } @@ -258,7 +270,7 @@ where .await .map_err(Into::into)?; - // Drop CRUD role + // Drop attached user self.execute_query(mysql::drop_user(db_name, host).as_str(), conn) .await .map_err(Into::into)?; @@ -437,7 +449,7 @@ pub(super) mod tests { // database must exist after creating through backend backend.init().await.unwrap(); - backend.create(db_id).await.unwrap(); + backend.create(db_id, true).await.unwrap(); assert!(database_exists(db_name, conn).await); } @@ -461,6 +473,54 @@ pub(super) mod tests { .await; } + pub async fn test_backend_creates_database_with_unrestricted_privileges(backend: impl Backend) { + async { + { + let db_id = Uuid::new_v4(); + let db_name = get_db_name(db_id); + let db_name = db_name.as_str(); + + // privileged operations + { + let conn_pool = get_privileged_connection_pool().await; + let conn = &mut conn_pool.get().await.unwrap(); + + // database must not exist + assert!(!database_exists(db_name, conn).await); + + // database must exist after creating through backend + backend.init().await.unwrap(); + backend.create(db_id, false).await.unwrap(); + assert!(database_exists(db_name, conn).await); + } + + // DML statements must succeed + { + let conn_pool = create_restricted_connection_pool(db_name).await; + let conn = &mut conn_pool.get().await.unwrap(); + for stmt in DML_STATEMENTS { + assert!(sql_query(stmt).execute(conn).await.is_ok()); + } + } + } + + // DDL statements must succeed + for stmt in DDL_STATEMENTS { + let db_id = Uuid::new_v4(); + let db_name = get_db_name(db_id); + let db_name = db_name.as_str(); + + backend.create(db_id, false).await.unwrap(); + let conn_pool = create_restricted_connection_pool(db_name).await; + let conn = &mut conn_pool.get().await.unwrap(); + + assert!(sql_query(stmt).execute(conn).await.is_ok()); + } + } + .lock_read() + .await; + } + pub async fn test_backend_cleans_database_with_tables(backend: impl Backend) { const NUM_BOOKS: i64 = 3; @@ -470,7 +530,7 @@ pub(super) mod tests { async { backend.init().await.unwrap(); - backend.create(db_id).await.unwrap(); + backend.create(db_id, true).await.unwrap(); table! { book (id) { @@ -522,14 +582,14 @@ pub(super) mod tests { async { backend.init().await.unwrap(); - backend.create(db_id).await.unwrap(); + backend.create(db_id, true).await.unwrap(); backend.clean(db_id).await.unwrap(); } .lock_read() .await; } - pub async fn test_backend_drops_database(backend: impl Backend) { + pub async fn test_backend_drops_database(backend: impl Backend, restricted: bool) { let db_id = Uuid::new_v4(); let db_name = get_db_name(db_id); let db_name = db_name.as_str(); @@ -540,11 +600,11 @@ pub(super) mod tests { // database must exist backend.init().await.unwrap(); - backend.create(db_id).await.unwrap(); + backend.create(db_id, restricted).await.unwrap(); assert!(database_exists(db_name, conn).await); // database must not exist - backend.drop(db_id).await.unwrap(); + backend.drop(db_id, true).await.unwrap(); assert!(!database_exists(db_name, conn).await); } .lock_read() @@ -576,7 +636,7 @@ pub(super) mod tests { .await; } - pub async fn test_pool_drops_created_databases(backend: impl Backend) { + pub async fn test_pool_drops_created_restricted_databases(backend: impl Backend) { const NUM_DBS: i64 = 3; let conn_pool = get_privileged_connection_pool().await; @@ -589,7 +649,7 @@ pub(super) mod tests { assert_eq!(count_all_databases(conn).await, 0); // fetch connection pools - let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // there must be databases assert_eq!(count_all_databases(conn).await, NUM_DBS); @@ -609,4 +669,35 @@ pub(super) mod tests { .lock_drop() .await; } + + pub async fn test_pool_drops_created_unrestricted_database(backend: impl Backend) { + let conn_pool = get_privileged_connection_pool().await; + let conn = &mut conn_pool.get().await.unwrap(); + + async { + let db_pool = backend.create_database_pool().await.unwrap(); + + // there must be no databases + assert_eq!(count_all_databases(conn).await, 0); + + // fetch connection pool + let conn_pool = db_pool.create_mutable().await.unwrap(); + + // there must be a database + assert_eq!(count_all_databases(conn).await, 1); + + // must drop database + drop(conn_pool); + + // there must be no databases + assert_eq!(count_all_databases(conn).await, 0); + + drop(db_pool); + + // there must be no databases + assert_eq!(count_all_databases(conn).await, 0); + } + .lock_drop() + .await; + } } diff --git a/src/async/backend/postgres/diesel.rs b/src/async/backend/postgres/diesel.rs index 950646f..08f3f4e 100644 --- a/src/async/backend/postgres/diesel.rs +++ b/src/async/backend/postgres/diesel.rs @@ -145,7 +145,7 @@ impl<'pool, P: DieselPoolAssociation> PostgresBackend<'pool> P::get_connection(&self.default_pool).await } - async fn establish_database_connection( + async fn establish_privileged_database_connection( &self, db_id: Uuid, ) -> ConnectionResult { @@ -156,6 +156,20 @@ impl<'pool, P: DieselPoolAssociation> PostgresBackend<'pool> AsyncPgConnection::establish(database_url.as_str()).await } + async fn establish_restricted_database_connection( + &self, + db_id: Uuid, + ) -> ConnectionResult { + let db_name = get_db_name(db_id); + let db_name = db_name.as_str(); + let database_url = self.privileged_config.restricted_database_connection_url( + db_name, + Some(db_name), + db_name, + ); + AsyncPgConnection::establish(database_url.as_str()).await + } + fn put_database_connection(&self, db_id: Uuid, conn: AsyncPgConnection) { self.db_conns.lock().insert(db_id, conn); } @@ -244,16 +258,25 @@ impl> Backend for DieselAsyncPostgre async fn create( &self, db_id: uuid::Uuid, + restrict_privileges: bool, ) -> Result> { - PostgresBackendWrapper::new(self).create(db_id).await + PostgresBackendWrapper::new(self) + .create(db_id, restrict_privileges) + .await } async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> { PostgresBackendWrapper::new(self).clean(db_id).await } - async fn drop(&self, db_id: uuid::Uuid) -> Result<(), BError> { - PostgresBackendWrapper::new(self).drop(db_id).await + async fn drop( + &self, + db_id: uuid::Uuid, + is_restricted: bool, + ) -> Result<(), BError> { + PostgresBackendWrapper::new(self) + .drop(db_id, is_restricted) + .await } } @@ -277,15 +300,23 @@ mod tests { CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS, }, }, - r#async::{backend::common::pool::diesel::bb8::DieselBb8, db_pool::DatabasePoolBuilder}, + r#async::{ + backend::{ + common::pool::diesel::bb8::DieselBb8, + postgres::r#trait::tests::test_pool_drops_created_unrestricted_database, + }, + db_pool::DatabasePoolBuilder, + }, }; use super::{ super::r#trait::tests::{ test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables, - test_backend_creates_database_with_restricted_privileges, test_backend_drops_database, - test_backend_drops_previous_databases, test_pool_drops_created_databases, - test_pool_drops_previous_databases, PgDropLock, + test_backend_creates_database_with_restricted_privileges, + test_backend_creates_database_with_unrestricted_privileges, + test_backend_drops_database, test_backend_drops_previous_databases, + test_pool_drops_created_restricted_databases, test_pool_drops_previous_databases, + PgDropLock, }, DieselAsyncPostgresBackend, }; @@ -326,7 +357,7 @@ mod tests { } #[test(flavor = "multi_thread", shared)] - async fn drops_previous_databases() { + async fn backend_drops_previous_databases() { test_backend_drops_previous_databases( create_backend(false).await, create_backend(false).await.drop_previous_databases(true), @@ -341,6 +372,12 @@ mod tests { test_backend_creates_database_with_restricted_privileges(backend).await; } + #[test(flavor = "multi_thread", shared)] + async fn backend_creates_database_with_unrestricted_privileges() { + let backend = create_backend(true).await.drop_previous_databases(false); + test_backend_creates_database_with_unrestricted_privileges(backend).await; + } + #[test(flavor = "multi_thread", shared)] async fn backend_cleans_database_with_tables() { let backend = create_backend(true).await.drop_previous_databases(false); @@ -354,9 +391,15 @@ mod tests { } #[test(flavor = "multi_thread", shared)] - async fn backend_drops_database() { + async fn backend_drops_restricted_database() { + let backend = create_backend(true).await.drop_previous_databases(false); + test_backend_drops_database(backend, true).await; + } + + #[test(flavor = "multi_thread", shared)] + async fn backend_drops_unrestricted_database() { let backend = create_backend(true).await.drop_previous_databases(false); - test_backend_drops_database(backend).await; + test_backend_drops_database(backend, false).await; } #[test(flavor = "multi_thread", shared)] @@ -377,7 +420,7 @@ mod tests { async { let db_pool = backend.create_database_pool().await.unwrap(); - let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // insert single row into each database join_all( @@ -426,7 +469,7 @@ mod tests { async { let db_pool = backend.create_database_pool().await.unwrap(); - let conn_pool = db_pool.pull().await; + let conn_pool = db_pool.pull_immutable().await; let conn = &mut conn_pool.get().await.unwrap(); // DDL statements must fail @@ -443,6 +486,33 @@ mod tests { .await; } + #[test(flavor = "multi_thread", shared)] + async fn pool_provides_unrestricted_databases() { + let backend = create_backend(true).await.drop_previous_databases(false); + + async { + let db_pool = backend.create_database_pool().await.unwrap(); + + // DML statements must succeed + { + let conn_pool = db_pool.create_mutable().await.unwrap(); + let conn = &mut conn_pool.get().await.unwrap(); + for stmt in DML_STATEMENTS { + assert!(sql_query(stmt).execute(conn).await.is_ok()); + } + } + + // DDL statements must succeed + for stmt in DDL_STATEMENTS { + let conn_pool = db_pool.create_mutable().await.unwrap(); + let conn = &mut conn_pool.get().await.unwrap(); + assert!(sql_query(stmt).execute(conn).await.is_ok()); + } + } + .lock_read() + .await; + } + #[test(flavor = "multi_thread", shared)] async fn pool_provides_clean_databases() { const NUM_DBS: i64 = 3; @@ -454,7 +524,7 @@ mod tests { // fetch connection pools the first time { - let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // databases must be empty join_all(conn_pools.iter().map(|conn_pool| async move { @@ -482,7 +552,7 @@ mod tests { // fetch same connection pools a second time { - let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // databases must be empty join_all(conn_pools.iter().map(|conn_pool| async move { @@ -500,8 +570,14 @@ mod tests { } #[test(flavor = "multi_thread", shared)] - async fn pool_drops_created_databases() { + async fn pool_drops_created_restricted_databases() { + let backend = create_backend(false).await; + test_pool_drops_created_restricted_databases(backend).await; + } + + #[test(flavor = "multi_thread", shared)] + async fn pool_drops_created_unrestricted_database() { let backend = create_backend(false).await; - test_pool_drops_created_databases(backend).await; + test_pool_drops_created_unrestricted_database(backend).await; } } diff --git a/src/async/backend/postgres/sea_orm.rs b/src/async/backend/postgres/sea_orm.rs index b3a7456..667b19e 100644 --- a/src/async/backend/postgres/sea_orm.rs +++ b/src/async/backend/postgres/sea_orm.rs @@ -145,7 +145,7 @@ impl<'pool> PostgresBackend<'pool> for SeaORMPostgresBackend { Ok(self.default_pool.clone().into()) } - async fn establish_database_connection( + async fn establish_privileged_database_connection( &self, db_id: Uuid, ) -> Result { @@ -157,6 +157,21 @@ impl<'pool> PostgresBackend<'pool> for SeaORMPostgresBackend { Database::connect(opts).await.map_err(Into::into) } + async fn establish_restricted_database_connection( + &self, + db_id: Uuid, + ) -> Result { + let db_name = get_db_name(db_id); + let db_name = db_name.as_str(); + let database_url = self.privileged_config.restricted_database_connection_url( + db_name, + Some(db_name), + db_name, + ); + let opts = ConnectOptions::new(database_url); + Database::connect(opts).await.map_err(Into::into) + } + fn put_database_connection(&self, db_id: Uuid, conn: DatabaseConnection) { self.db_conns.lock().insert(db_id, conn); } @@ -272,16 +287,24 @@ impl Backend for SeaORMPostgresBackend { PostgresBackendWrapper::new(self).init().await } - async fn create(&self, db_id: uuid::Uuid) -> Result { - PostgresBackendWrapper::new(self).create(db_id).await + async fn create( + &self, + db_id: uuid::Uuid, + restrict_privileges: bool, + ) -> Result { + PostgresBackendWrapper::new(self) + .create(db_id, restrict_privileges) + .await } async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> { PostgresBackendWrapper::new(self).clean(db_id).await } - async fn drop(&self, db_id: uuid::Uuid) -> Result<(), BError> { - PostgresBackendWrapper::new(self).drop(db_id).await + async fn drop(&self, db_id: uuid::Uuid, is_restricted: bool) -> Result<(), BError> { + PostgresBackendWrapper::new(self) + .drop(db_id, is_restricted) + .await } } @@ -305,14 +328,20 @@ mod tests { CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS, }, }, - r#async::db_pool::DatabasePoolBuilder, + r#async::{ + backend::postgres::r#trait::tests::{ + test_backend_drops_database, test_pool_drops_created_unrestricted_database, + }, + db_pool::DatabasePoolBuilder, + }, }; use super::{ super::r#trait::tests::{ test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables, - test_backend_creates_database_with_restricted_privileges, test_backend_drops_database, - test_backend_drops_previous_databases, test_pool_drops_created_databases, + test_backend_creates_database_with_restricted_privileges, + test_backend_creates_database_with_unrestricted_privileges, + test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases, test_pool_drops_previous_databases, PgDropLock, }, SeaORMPostgresBackend, @@ -369,6 +398,12 @@ mod tests { test_backend_creates_database_with_restricted_privileges(backend).await; } + #[test(flavor = "multi_thread", shared)] + async fn backend_creates_database_with_unrestricted_privileges() { + let backend = create_backend(true).await.drop_previous_databases(false); + test_backend_creates_database_with_unrestricted_privileges(backend).await; + } + #[test(flavor = "multi_thread", shared)] async fn backend_cleans_database_with_tables() { let backend = create_backend(true).await.drop_previous_databases(false); @@ -382,9 +417,15 @@ mod tests { } #[test(flavor = "multi_thread", shared)] - async fn backend_drops_database() { + async fn backend_drops_restricted_database() { + let backend = create_backend(true).await.drop_previous_databases(false); + test_backend_drops_database(backend, true).await; + } + + #[test(flavor = "multi_thread", shared)] + async fn backend_drops_unrestricted_database() { let backend = create_backend(true).await.drop_previous_databases(false); - test_backend_drops_database(backend).await; + test_backend_drops_database(backend, false).await; } #[test(flavor = "multi_thread", shared)] @@ -410,7 +451,7 @@ mod tests { async { let db_pool = backend.create_database_pool().await.unwrap(); - let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // insert single row into each database join_all(conns.iter().enumerate().map(|(i, conn)| async move { @@ -449,7 +490,7 @@ mod tests { async { let db_pool = backend.create_database_pool().await.unwrap(); - let conn = db_pool.pull().await; + let conn = db_pool.pull_immutable().await; // DDL statements must fail for stmt in DDL_STATEMENTS { @@ -465,6 +506,31 @@ mod tests { .await; } + #[test(flavor = "multi_thread", shared)] + async fn pool_provides_unrestricted_databases() { + let backend = create_backend(true).await.drop_previous_databases(false); + + async { + let db_pool = backend.create_database_pool().await.unwrap(); + + // DML statements must succeed + { + let conn = db_pool.create_mutable().await.unwrap(); + for stmt in DML_STATEMENTS { + assert!(conn.execute_unprepared(stmt).await.is_ok()); + } + } + + // DDL statements must succeed + for stmt in DDL_STATEMENTS { + let conn = db_pool.create_mutable().await.unwrap(); + assert!(conn.execute_unprepared(stmt).await.is_ok()); + } + } + .lock_read() + .await; + } + #[test(flavor = "multi_thread", shared)] async fn pool_provides_clean_databases() { const NUM_DBS: i64 = 3; @@ -476,7 +542,7 @@ mod tests { // fetch connection pools the first time { - let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // databases must be empty join_all(conns.iter().map(|conn| async move { @@ -497,7 +563,7 @@ mod tests { // fetch same connection pools a second time { - let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conns = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // databases must be empty join_all(conns.iter().map(|conn| async move { @@ -511,8 +577,14 @@ mod tests { } #[test(flavor = "multi_thread", shared)] - async fn pool_drops_created_databases() { + async fn pool_drops_created_restricted_databases() { + let backend = create_backend(false).await; + test_pool_drops_created_restricted_databases(backend).await; + } + + #[test(flavor = "multi_thread", shared)] + async fn pool_drops_created_unrestricted_database() { let backend = create_backend(false).await; - test_pool_drops_created_databases(backend).await; + test_pool_drops_created_unrestricted_database(backend).await; } } diff --git a/src/async/backend/postgres/sqlx.rs b/src/async/backend/postgres/sqlx.rs index 1d8b2d3..ead2302 100644 --- a/src/async/backend/postgres/sqlx.rs +++ b/src/async/backend/postgres/sqlx.rs @@ -127,7 +127,7 @@ impl<'pool> PostgresBackend<'pool> for SqlxPostgresBackend { self.default_pool.acquire().await.map_err(Into::into) } - async fn establish_database_connection( + async fn establish_privileged_database_connection( &self, db_id: Uuid, ) -> Result { @@ -136,6 +136,21 @@ impl<'pool> PostgresBackend<'pool> for SqlxPostgresBackend { PgConnection::connect_with(&opts).await.map_err(Into::into) } + async fn establish_restricted_database_connection( + &self, + db_id: Uuid, + ) -> Result { + let db_name = get_db_name(db_id); + let db_name = db_name.as_str(); + let opts = self + .privileged_opts + .clone() + .username(db_name) + .password(db_name) + .database(db_name); + PgConnection::connect_with(&opts).await.map_err(Into::into) + } + fn put_database_connection(&self, db_id: Uuid, conn: PgConnection) { self.db_conns.lock().insert(db_id, conn); } @@ -205,16 +220,20 @@ impl Backend for SqlxPostgresBackend { PostgresBackendWrapper::new(self).init().await } - async fn create(&self, db_id: uuid::Uuid) -> Result { - PostgresBackendWrapper::new(self).create(db_id).await + async fn create(&self, db_id: uuid::Uuid, restrict_privileges: bool) -> Result { + PostgresBackendWrapper::new(self) + .create(db_id, restrict_privileges) + .await } async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> { PostgresBackendWrapper::new(self).clean(db_id).await } - async fn drop(&self, db_id: uuid::Uuid) -> Result<(), BError> { - PostgresBackendWrapper::new(self).drop(db_id).await + async fn drop(&self, db_id: uuid::Uuid, is_restricted: bool) -> Result<(), BError> { + PostgresBackendWrapper::new(self) + .drop(db_id, is_restricted) + .await } } @@ -233,14 +252,20 @@ mod tests { common::statement::postgres::tests::{ CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS, }, - r#async::db_pool::DatabasePoolBuilder, + r#async::{ + backend::postgres::r#trait::tests::{ + test_backend_creates_database_with_unrestricted_privileges, + test_backend_drops_database, test_pool_drops_created_unrestricted_database, + }, + db_pool::DatabasePoolBuilder, + }, }; use super::{ super::r#trait::tests::{ test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables, - test_backend_creates_database_with_restricted_privileges, test_backend_drops_database, - test_backend_drops_previous_databases, test_pool_drops_created_databases, + test_backend_creates_database_with_restricted_privileges, + test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases, test_pool_drops_previous_databases, PgDropLock, }, SqlxPostgresBackend, @@ -289,6 +314,12 @@ mod tests { test_backend_creates_database_with_restricted_privileges(backend).await; } + #[test(flavor = "multi_thread", shared)] + async fn backend_creates_database_with_unrestricted_privileges() { + let backend = create_backend(true).drop_previous_databases(false); + test_backend_creates_database_with_unrestricted_privileges(backend).await; + } + #[test(flavor = "multi_thread", shared)] async fn backend_cleans_database_with_tables() { let backend = create_backend(true).drop_previous_databases(false); @@ -302,9 +333,15 @@ mod tests { } #[test(flavor = "multi_thread", shared)] - async fn backend_drops_database() { + async fn backend_drops_restricted_database() { let backend = create_backend(true).drop_previous_databases(false); - test_backend_drops_database(backend).await; + test_backend_drops_database(backend, true).await; + } + + #[test(flavor = "multi_thread", shared)] + async fn backend_drops_unrestricted_database() { + let backend = create_backend(true).drop_previous_databases(false); + test_backend_drops_database(backend, false).await; } #[test(flavor = "multi_thread", shared)] @@ -330,7 +367,7 @@ mod tests { async { let db_pool = backend.create_database_pool().await.unwrap(); - let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // insert single row into each database join_all( @@ -377,7 +414,7 @@ mod tests { async { let db_pool = backend.create_database_pool().await.unwrap(); - let conn_pool = db_pool.pull().await; + let conn_pool = db_pool.pull_immutable().await; let conn = &mut conn_pool.acquire().await.unwrap(); // DDL statements must fail @@ -394,6 +431,33 @@ mod tests { .await; } + #[test(flavor = "multi_thread", shared)] + async fn pool_provides_unrestricted_databases() { + let backend = create_backend(true).drop_previous_databases(false); + + async { + let db_pool = backend.create_database_pool().await.unwrap(); + + // DML statements must succeed + { + let conn_pool = db_pool.create_mutable().await.unwrap(); + let conn = &mut conn_pool.acquire().await.unwrap(); + for stmt in DML_STATEMENTS { + assert!(conn.execute(stmt).await.is_ok()); + } + } + + // DDL statements must succeed + for stmt in DDL_STATEMENTS { + let conn_pool = db_pool.create_mutable().await.unwrap(); + let conn = &mut conn_pool.acquire().await.unwrap(); + assert!(conn.execute(stmt).await.is_ok()); + } + } + .lock_read() + .await; + } + #[test(flavor = "multi_thread", shared)] async fn pool_provides_clean_databases() { const NUM_DBS: i64 = 3; @@ -405,7 +469,7 @@ mod tests { // fetch connection pools the first time { - let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // databases must be empty join_all(conn_pools.iter().map(|conn_pool| async move { @@ -433,7 +497,7 @@ mod tests { // fetch same connection pools a second time { - let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // databases must be empty join_all(conn_pools.iter().map(|conn_pool| async move { @@ -454,8 +518,14 @@ mod tests { } #[test(flavor = "multi_thread", shared)] - async fn pool_drops_created_databases() { + async fn pool_drops_created_restricted_databases() { + let backend = create_backend(false); + test_pool_drops_created_restricted_databases(backend).await; + } + + #[test(flavor = "multi_thread", shared)] + async fn pool_drops_created_unrestricted_database() { let backend = create_backend(false); - test_pool_drops_created_databases(backend).await; + test_pool_drops_created_unrestricted_database(backend).await; } } diff --git a/src/async/backend/postgres/tokio_postgres.rs b/src/async/backend/postgres/tokio_postgres.rs index 86874e4..12ec16c 100644 --- a/src/async/backend/postgres/tokio_postgres.rs +++ b/src/async/backend/postgres/tokio_postgres.rs @@ -137,7 +137,10 @@ impl<'pool, P: TokioPostgresPoolAssociation> PostgresBackend<'pool> for TokioPos P::get_connection(&self.default_pool).await } - async fn establish_database_connection(&self, db_id: Uuid) -> Result { + async fn establish_privileged_database_connection( + &self, + db_id: Uuid, + ) -> Result { let mut config = self.privileged_config.clone(); let db_name = get_db_name(db_id); config.dbname(db_name.as_str()); @@ -146,6 +149,19 @@ impl<'pool, P: TokioPostgresPoolAssociation> PostgresBackend<'pool> for TokioPos Ok(client) } + async fn establish_restricted_database_connection( + &self, + db_id: Uuid, + ) -> Result { + let mut config = self.privileged_config.clone(); + let db_name = get_db_name(db_id); + let db_name = db_name.as_str(); + config.user(db_name).password(db_name).dbname(db_name); + let (client, connection) = config.connect(NoTls).await?; + tokio::spawn(connection); + Ok(client) + } + fn put_database_connection(&self, db_id: Uuid, conn: Client) { self.db_conns.lock().insert(db_id, conn); } @@ -217,16 +233,25 @@ impl Backend for TokioPostgresBackend

{ async fn create( &self, db_id: uuid::Uuid, + restrict_privileges: bool, ) -> Result> { - PostgresBackendWrapper::new(self).create(db_id).await + PostgresBackendWrapper::new(self) + .create(db_id, restrict_privileges) + .await } async fn clean(&self, db_id: uuid::Uuid) -> Result<(), BError> { PostgresBackendWrapper::new(self).clean(db_id).await } - async fn drop(&self, db_id: uuid::Uuid) -> Result<(), BError> { - PostgresBackendWrapper::new(self).drop(db_id).await + async fn drop( + &self, + db_id: uuid::Uuid, + is_restricted: bool, + ) -> Result<(), BError> { + PostgresBackendWrapper::new(self) + .drop(db_id, is_restricted) + .await } } @@ -244,7 +269,13 @@ mod tests { CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS, }, r#async::{ - backend::common::pool::tokio_postgres::bb8::TokioPostgresBb8, + backend::{ + common::pool::tokio_postgres::bb8::TokioPostgresBb8, + postgres::r#trait::tests::{ + test_backend_creates_database_with_unrestricted_privileges, + test_backend_drops_database, test_pool_drops_created_unrestricted_database, + }, + }, db_pool::DatabasePoolBuilder, }, }; @@ -252,8 +283,8 @@ mod tests { use super::{ super::r#trait::tests::{ test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables, - test_backend_creates_database_with_restricted_privileges, test_backend_drops_database, - test_backend_drops_previous_databases, test_pool_drops_created_databases, + test_backend_creates_database_with_restricted_privileges, + test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases, test_pool_drops_previous_databases, PgDropLock, }, TokioPostgresBackend, @@ -299,6 +330,12 @@ mod tests { test_backend_creates_database_with_restricted_privileges(backend).await; } + #[test(flavor = "multi_thread", shared)] + async fn backend_creates_database_with_unrestricted_privileges() { + let backend = create_backend(true).await.drop_previous_databases(false); + test_backend_creates_database_with_unrestricted_privileges(backend).await; + } + #[test(flavor = "multi_thread", shared)] async fn backend_cleans_database_with_tables() { let backend = create_backend(true).await.drop_previous_databases(false); @@ -312,9 +349,15 @@ mod tests { } #[test(flavor = "multi_thread", shared)] - async fn backend_drops_database() { + async fn backend_drops_restricted_database() { let backend = create_backend(true).await.drop_previous_databases(false); - test_backend_drops_database(backend).await; + test_backend_drops_database(backend, true).await; + } + + #[test(flavor = "multi_thread", shared)] + async fn backend_drops_unrestricted_database() { + let backend = create_backend(true).await.drop_previous_databases(false); + test_backend_drops_database(backend, false).await; } #[test(flavor = "multi_thread", shared)] @@ -335,7 +378,7 @@ mod tests { async { let db_pool = backend.create_database_pool().await.unwrap(); - let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // insert single row into each database join_all( @@ -385,7 +428,7 @@ mod tests { async { let db_pool = backend.create_database_pool().await.unwrap(); - let conn_pool = db_pool.pull().await; + let conn_pool = db_pool.pull_immutable().await; let conn = &mut conn_pool.get().await.unwrap(); // DDL statements must fail @@ -402,6 +445,33 @@ mod tests { .await; } + #[test(flavor = "multi_thread", shared)] + async fn pool_provides_unrestricted_databases() { + let backend = create_backend(true).await.drop_previous_databases(false); + + async { + let db_pool = backend.create_database_pool().await.unwrap(); + + // DML statements must succeed + { + let conn_pool = db_pool.create_mutable().await.unwrap(); + let conn = &mut conn_pool.get().await.unwrap(); + for stmt in DML_STATEMENTS { + assert!(conn.execute(stmt, &[]).await.is_ok()); + } + } + + // DDL statements must succeed + for stmt in DDL_STATEMENTS { + let conn_pool = db_pool.create_mutable().await.unwrap(); + let conn = &mut conn_pool.get().await.unwrap(); + assert!(conn.execute(stmt, &[]).await.is_ok()); + } + } + .lock_read() + .await; + } + #[test(flavor = "multi_thread", shared)] async fn pool_provides_clean_databases() { const NUM_DBS: i64 = 3; @@ -413,7 +483,7 @@ mod tests { // fetch connection pools the first time { - let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // databases must be empty join_all(conn_pools.iter().map(|conn_pool| async move { @@ -440,7 +510,7 @@ mod tests { // fetch same connection pools a second time { - let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // databases must be empty join_all(conn_pools.iter().map(|conn_pool| async move { @@ -461,8 +531,14 @@ mod tests { } #[test(flavor = "multi_thread", shared)] - async fn pool_drops_created_databases() { + async fn pool_drops_created_restricted_databases() { + let backend = create_backend(false).await; + test_pool_drops_created_restricted_databases(backend).await; + } + + #[test(flavor = "multi_thread", shared)] + async fn pool_drops_created_unrestricted_database() { let backend = create_backend(false).await; - test_pool_drops_created_databases(backend).await; + test_pool_drops_created_unrestricted_database(backend).await; } } diff --git a/src/async/backend/postgres/trait.rs b/src/async/backend/postgres/trait.rs index 81b4539..c14b10f 100644 --- a/src/async/backend/postgres/trait.rs +++ b/src/async/backend/postgres/trait.rs @@ -64,7 +64,11 @@ pub(super) trait PostgresBackend<'pool>: Send + Sync + 'static { async fn get_default_connection(&'pool self) -> Result; - async fn establish_database_connection( + async fn establish_privileged_database_connection( + &self, + db_id: Uuid, + ) -> Result; + async fn establish_restricted_database_connection( &self, db_id: Uuid, ) -> Result; @@ -158,46 +162,47 @@ where pub(super) async fn create( &'backend self, db_id: Uuid, + restrict_privileges: bool, ) -> Result> { // Get database name based on UUID let db_name = get_db_name(db_id); let db_name = db_name.as_str(); - { - // Get connection to default database as privileged user - let conn = &mut self.get_default_connection().await.map_err(Into::into)?; + // Get connection to default database as privileged user + let default_conn = &mut self.get_default_connection().await.map_err(Into::into)?; - // Create database - self.execute_query(postgres::create_database(db_name).as_str(), conn) - .await - .map_err(Into::into)?; + // Create database + self.execute_query(postgres::create_database(db_name).as_str(), default_conn) + .await + .map_err(Into::into)?; - // Create CRUD role - self.execute_query(postgres::create_role(db_name).as_str(), conn) - .await - .map_err(Into::into)?; - } + // Create role + self.execute_query(postgres::create_role(db_name).as_str(), default_conn) + .await + .map_err(Into::into)?; - { + if restrict_privileges { // Connect to database as privileged user let conn = self - .establish_database_connection(db_id) + .establish_privileged_database_connection(db_id) .await .map_err(Into::into)?; - // Create entities + // Create entities as privileged user let mut conn = self.create_entities(conn).await; - // Grant privileges to CRUD role + // Grant table privileges to restricted role self.execute_query( - postgres::grant_table_privileges(db_name).as_str(), + postgres::grant_restricted_table_privileges(db_name).as_str(), &mut conn, ) .await .map_err(Into::into)?; + + // Grant sequence privileges to restricted role self.execute_query( - postgres::grant_sequence_privileges(db_name).as_str(), + postgres::grant_restricted_sequence_privileges(db_name).as_str(), &mut conn, ) .await @@ -205,13 +210,31 @@ where // Store database connection for reuse when cleaning self.put_database_connection(db_id, conn); - } + } else { + // Grant database ownership to database-unrestricted role + self.execute_query( + postgres::grant_database_ownership(db_name, db_name).as_str(), + default_conn, + ) + .await + .map_err(Into::into)?; - // Create connection pool with CRUD role + // Connect to database as database-unrestricted user + let conn = self + .establish_restricted_database_connection(db_id) + .await + .map_err(Into::into)?; + + // Create entities as database-unrestricted user + let _ = self.create_entities(conn).await; + }; + + // Create connection pool with attached role let pool = self .create_connection_pool(db_id) .await .map_err(Into::into)?; + Ok(pool) } @@ -245,10 +268,11 @@ where pub(super) async fn drop( &'backend self, db_id: Uuid, + is_restricted: bool, ) -> Result<(), BackendError> { // Drop privileged connection to database - { + if is_restricted { self.get_database_connection(db_id); } @@ -264,7 +288,7 @@ where .await .map_err(Into::into)?; - // Drop CRUD role + // Drop attached role self.execute_query(postgres::drop_role(db_name).as_str(), conn) .await .map_err(Into::into)?; @@ -282,7 +306,10 @@ pub(super) mod tests { use diesel_async::{ pooled_connection::AsyncDieselConnectionManager, AsyncPgConnection, RunQueryDsl, }; - use futures::{future::join_all, Future}; + use futures::{ + future::{join_all, try_join_all}, + Future, + }; use tokio::sync::OnceCell; use uuid::Uuid; @@ -454,7 +481,7 @@ pub(super) mod tests { // database must exist after creating through backend backend.init().await.unwrap(); - backend.create(db_id).await.unwrap(); + backend.create(db_id, true).await.unwrap(); assert!(database_exists(db_name, conn).await); } @@ -478,6 +505,60 @@ pub(super) mod tests { .await; } + pub async fn test_backend_creates_database_with_unrestricted_privileges(backend: impl Backend) { + async { + { + let db_id = Uuid::new_v4(); + let db_name = get_db_name(db_id); + let db_name = db_name.as_str(); + + // privileged operations + { + let conn_pool = get_privileged_connection_pool().await; + let conn = &mut conn_pool.get().await.unwrap(); + + // database must not exist + assert!(!database_exists(db_name, conn).await); + + // database must exist after creating through backend + backend.init().await.unwrap(); + backend.create(db_id, false).await.unwrap(); + assert!(database_exists(db_name, conn).await); + } + + // DML statements must succeed + { + let conn_pool = create_restricted_connection_pool(db_name).await; + let conn = &mut conn_pool.get().await.unwrap(); + for stmt in DML_STATEMENTS { + let result = sql_query(stmt).execute(conn).await; + assert!(result.is_ok()); + } + } + } + + // DDL statements must succeed + try_join_all(DDL_STATEMENTS.iter().map(|stmt| { + let backend = &backend; + async move { + let db_id = Uuid::new_v4(); + let db_name = get_db_name(db_id); + let db_name = db_name.as_str(); + + backend.create(db_id, false).await.unwrap(); + let conn_pool = create_restricted_connection_pool(db_name).await; + let conn = &mut conn_pool.get().await.unwrap(); + + sql_query(*stmt).execute(conn).await + } + })) + .await + .unwrap(); + } + .lock_read() + .await; + } + pub async fn test_backend_cleans_database_with_tables(backend: impl Backend) { const NUM_BOOKS: i64 = 3; @@ -487,7 +568,7 @@ pub(super) mod tests { async { backend.init().await.unwrap(); - backend.create(db_id).await.unwrap(); + backend.create(db_id, true).await.unwrap(); let conn_pool = &mut create_restricted_connection_pool(db_name).await; let conn = &mut conn_pool.get().await.unwrap(); @@ -517,14 +598,14 @@ pub(super) mod tests { async { backend.init().await.unwrap(); - backend.create(db_id).await.unwrap(); + backend.create(db_id, true).await.unwrap(); backend.clean(db_id).await.unwrap(); } .lock_read() .await; } - pub async fn test_backend_drops_database(backend: impl Backend) { + pub async fn test_backend_drops_database(backend: impl Backend, restricted: bool) { let db_id = Uuid::new_v4(); let db_name = get_db_name(db_id); let db_name = db_name.as_str(); @@ -535,11 +616,11 @@ pub(super) mod tests { async { // database must exist backend.init().await.unwrap(); - backend.create(db_id).await.unwrap(); + backend.create(db_id, restricted).await.unwrap(); assert!(database_exists(db_name, conn).await); // database must not exist - backend.drop(db_id).await.unwrap(); + backend.drop(db_id, restricted).await.unwrap(); assert!(!database_exists(db_name, conn).await); } .lock_read() @@ -571,35 +652,66 @@ pub(super) mod tests { .await; } - pub async fn test_pool_drops_created_databases(backend: impl Backend) { + pub async fn test_pool_drops_created_restricted_databases(backend: impl Backend) { const NUM_DBS: i64 = 3; - let privileged_conn_pool = get_privileged_connection_pool().await; - let privileged_conn = &mut privileged_conn_pool.get().await.unwrap(); + let conn_pool = get_privileged_connection_pool().await; + let conn = &mut conn_pool.get().await.unwrap(); async { let db_pool = backend.create_database_pool().await.unwrap(); // there must be no databases - assert_eq!(count_all_databases(privileged_conn).await, 0); + assert_eq!(count_all_databases(conn).await, 0); // fetch connection pools - let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull())).await; + let conn_pools = join_all((0..NUM_DBS).map(|_| db_pool.pull_immutable())).await; // there must be databases - assert_eq!(count_all_databases(privileged_conn).await, NUM_DBS); + assert_eq!(count_all_databases(conn).await, NUM_DBS); // must release databases back to pool drop(conn_pools); // there must be databases - assert_eq!(count_all_databases(privileged_conn).await, NUM_DBS); + assert_eq!(count_all_databases(conn).await, NUM_DBS); // must drop databases drop(db_pool); // there must be no databases - assert_eq!(count_all_databases(privileged_conn).await, 0); + assert_eq!(count_all_databases(conn).await, 0); + } + .lock_drop() + .await; + } + + pub async fn test_pool_drops_created_unrestricted_database(backend: impl Backend) { + let conn_pool = get_privileged_connection_pool().await; + let conn = &mut conn_pool.get().await.unwrap(); + + async { + let db_pool = backend.create_database_pool().await.unwrap(); + + // there must be no databases + assert_eq!(count_all_databases(conn).await, 0); + + // fetch connection pool + let conn_pool = db_pool.create_mutable().await.unwrap(); + + // there must be a database + assert_eq!(count_all_databases(conn).await, 1); + + // must drop database + drop(conn_pool); + + // there must be no databases + assert_eq!(count_all_databases(conn).await, 0); + + drop(db_pool); + + // there must be no databases + assert_eq!(count_all_databases(conn).await, 0); } .lock_drop() .await; diff --git a/src/async/backend/trait.rs b/src/async/backend/trait.rs index 5eaa122..715b0d2 100644 --- a/src/async/backend/trait.rs +++ b/src/async/backend/trait.rs @@ -30,6 +30,7 @@ pub trait Backend: Sized + Send + Sync + 'static { async fn create( &self, db_id: Uuid, + restrict_privileges: bool, ) -> Result< Self::Pool, Error, @@ -45,5 +46,6 @@ pub trait Backend: Sized + Send + Sync + 'static { async fn drop( &self, db_id: Uuid, + is_restricted: bool, ) -> Result<(), Error>; } diff --git a/src/async/conn_pool.rs b/src/async/conn_pool.rs index 10cf4f4..6d3ecb3 100644 --- a/src/async/conn_pool.rs +++ b/src/async/conn_pool.rs @@ -4,53 +4,97 @@ use uuid::Uuid; use super::backend::{r#trait::Backend, Error as BackendError}; -/// Connection pool wrapper -pub struct ConnectionPool { +struct ConnectionPool { backend: Arc, db_id: Uuid, conn_pool: Option, + is_restricted: bool } -impl ConnectionPool { +impl Deref for ConnectionPool { + type Target = B::Pool; + + fn deref(&self) -> &Self::Target { + self.conn_pool + .as_ref() + .expect("conn_pool must always contain a [Some] value") + } +} + +impl Drop for ConnectionPool { + fn drop(&mut self) { + self.conn_pool = None; + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(async { + (*self.backend) + .drop(self.db_id, self.is_restricted) + .await + .ok(); + }); + }); + } +} + +/// Reusable connection pool wrapper +pub struct ReusableConnectionPool(ConnectionPool); + +impl ReusableConnectionPool { pub(crate) async fn new( backend: Arc, ) -> Result> { let db_id = Uuid::new_v4(); - let conn_pool = backend.create(db_id).await?; + let conn_pool = backend.create(db_id, true).await?; - Ok(Self { + Ok(Self(ConnectionPool { backend, db_id, conn_pool: Some(conn_pool), - }) + is_restricted: true + })) } pub(crate) async fn clean( &mut self, ) -> Result<(), BackendError> { - self.backend.clean(self.db_id).await + self.0.backend.clean(self.0.db_id).await } } -impl Deref for ConnectionPool { +impl Deref for ReusableConnectionPool { type Target = B::Pool; fn deref(&self) -> &Self::Target { - self.conn_pool - .as_ref() - .expect("conn_pool must always contain a [Some] value") + &self.0 } } -impl Drop for ConnectionPool { - fn drop(&mut self) { - self.conn_pool = None; - tokio::task::block_in_place(|| { - tokio::runtime::Handle::current().block_on(async { - (*self.backend).drop(self.db_id).await.ok(); - }); - }); +/// Single-use connection pool wrapper +pub struct SingleUseConnectionPool(ConnectionPool); + +impl SingleUseConnectionPool { + pub(crate) async fn new( + backend: Arc, + ) -> Result> + { + let db_id = Uuid::new_v4(); + let conn_pool = backend.create(db_id, false).await?; + + Ok(Self(ConnectionPool { + backend, + db_id, + conn_pool: Some(conn_pool), + is_restricted: false + }, + )) + } +} + +impl Deref for SingleUseConnectionPool { + type Target = B::Pool; + + fn deref(&self) -> &Self::Target { + &self.0 } } diff --git a/src/async/db_pool.rs b/src/async/db_pool.rs index 8a84776..cd58f9a 100644 --- a/src/async/db_pool.rs +++ b/src/async/db_pool.rs @@ -4,15 +4,23 @@ use async_trait::async_trait; use super::{ backend::{r#trait::Backend, Error}, - conn_pool::ConnectionPool, + conn_pool::{ReusableConnectionPool as ReusableConnectionPoolInner, SingleUseConnectionPool}, object_pool::{ObjectPool, Reusable}, }; +/// Wrapper for a reusable connection pool wrapped in a reusable object wrapper +pub type ReusableConnectionPool<'a, B> = Reusable<'a, ReusableConnectionPoolInner>; + /// Database pool -pub struct DatabasePool(ObjectPool>); +pub struct DatabasePool { + backend: Arc, + object_pool: ObjectPool>, +} impl DatabasePool { /// Pulls a reusable connection pool + /// + /// Privileges are granted only for ``SELECT``, ``INSERT``, ``UPDATE``, and ``DELETE`` operations. /// # Example /// ``` /// use bb8::Pool; @@ -47,14 +55,65 @@ impl DatabasePool { /// .unwrap(); /// /// let db_pool = backend.create_database_pool().await.unwrap(); - /// let conn_pool = db_pool.pull(); + /// let conn_pool = db_pool.pull_immutable(); /// } /// /// tokio_test::block_on(f()); /// ``` #[must_use] - pub async fn pull(&self) -> Reusable> { - self.0.pull().await + pub async fn pull_immutable(&self) -> ReusableConnectionPool { + self.object_pool.pull().await + } + + /// Creates a single-use connection pool + /// + /// All privileges are granted. + /// # Example + /// ``` + /// use bb8::Pool; + /// use db_pool::{ + /// r#async::{DatabasePoolBuilderTrait, DieselAsyncPostgresBackend, DieselBb8}, + /// PrivilegedPostgresConfig, + /// }; + /// use diesel::sql_query; + /// use diesel_async::RunQueryDsl; + /// use dotenvy::dotenv; + /// + /// async fn f() { + /// dotenv().ok(); + /// + /// let config = PrivilegedPostgresConfig::from_env().unwrap(); + /// + /// let backend = DieselAsyncPostgresBackend::::new( + /// config, + /// || Pool::builder().max_size(10), + /// || Pool::builder().max_size(2), + /// move |mut conn| { + /// Box::pin(async { + /// sql_query("CREATE TABLE book(id SERIAL PRIMARY KEY, title TEXT NOT NULL)") + /// .execute(&mut conn) + /// .await + /// .unwrap(); + /// conn + /// }) + /// }, + /// ) + /// .await + /// .unwrap(); + /// + /// let db_pool = backend.create_database_pool().await.unwrap(); + /// let conn_pool = db_pool.create_mutable(); + /// } + /// + /// tokio_test::block_on(f()); + /// ``` + pub async fn create_mutable( + &self, + ) -> Result< + SingleUseConnectionPool, + Error, + > { + SingleUseConnectionPool::new(self.backend.clone()).await } } @@ -108,26 +167,32 @@ pub trait DatabasePoolBuilder: Backend { > { self.init().await?; let backend = Arc::new(self); - let object_pool = ObjectPool::new( - move || { - let backend = backend.clone(); - Box::pin(async { - ConnectionPool::new(backend) - .await - .expect("connection pool creation must succeed") - }) - }, - |mut conn_pool| { - Box::pin(async { - conn_pool - .clean() - .await - .expect("connection pool cleaning must succeed"); - conn_pool - }) - }, - ); - Ok(DatabasePool(object_pool)) + let object_pool = { + let backend = backend.clone(); + ObjectPool::new( + move || { + let backend = backend.clone(); + Box::pin(async { + ReusableConnectionPoolInner::new(backend) + .await + .expect("connection pool creation must succeed") + }) + }, + |mut conn_pool| { + Box::pin(async { + conn_pool + .clean() + .await + .expect("connection pool cleaning must succeed"); + conn_pool + }) + }, + ) + }; + Ok(DatabasePool { + backend, + object_pool, + }) } } diff --git a/src/async/mod.rs b/src/async/mod.rs index 8f465c1..40e1ad7 100644 --- a/src/async/mod.rs +++ b/src/async/mod.rs @@ -5,7 +5,8 @@ mod object_pool; mod wrapper; pub use backend::*; -pub use conn_pool::ConnectionPool; -pub use db_pool::{DatabasePool, DatabasePoolBuilder as DatabasePoolBuilderTrait}; -pub use object_pool::Reusable; +pub use conn_pool::SingleUseConnectionPool; +pub use db_pool::{ + DatabasePool, DatabasePoolBuilder as DatabasePoolBuilderTrait, ReusableConnectionPool, +}; pub use wrapper::PoolWrapper; diff --git a/src/async/wrapper.rs b/src/async/wrapper.rs index 06b588f..9f98bde 100644 --- a/src/async/wrapper.rs +++ b/src/async/wrapper.rs @@ -1,13 +1,13 @@ use std::ops::Deref; -use super::{backend::r#trait::Backend, conn_pool::ConnectionPool, object_pool::Reusable}; +use super::{backend::r#trait::Backend, db_pool::ReusableConnectionPool}; /// Connection pool wrapper to facilitate the use of pools in code under test and reusable pools in tests pub enum PoolWrapper { /// Connection pool used in code under test Pool(B::Pool), /// Reusable connection pool used in tests - ReusablePool(Reusable<'static, ConnectionPool>), + ReusablePool(ReusableConnectionPool<'static, B>), } impl Deref for PoolWrapper { diff --git a/src/common/statement/mysql.rs b/src/common/statement/mysql.rs index b206544..bb3d987 100644 --- a/src/common/statement/mysql.rs +++ b/src/common/statement/mysql.rs @@ -1,5 +1,3 @@ -pub const USE_DEFAULT_DATABASE: &str = "USE information_schema"; - #[allow(dead_code)] pub const GET_DATABASE_NAMES: &str = "SELECT schema_name FROM information_schema.schemata WHERE schema_name LIKE 'db_pool_%';"; @@ -7,6 +5,8 @@ pub const GET_DATABASE_NAMES: &str = pub const TURN_OFF_FOREIGN_KEY_CHECKS: &str = "SET FOREIGN_KEY_CHECKS = 0"; pub const TURN_ON_FOREIGN_KEY_CHECKS: &str = "SET FOREIGN_KEY_CHECKS = 1"; +pub const USE_DEFAULT_DATABASE: &str = "USE information_schema"; + pub fn create_database(db_name: &str) -> String { format!("CREATE DATABASE {db_name}") } @@ -19,7 +19,11 @@ pub fn use_database(db_name: &str) -> String { format!("USE {db_name}") } -pub fn grant_privileges(db_name: &str, host: &str) -> String { +pub fn grant_all_privileges(db_name: &str, host: &str) -> String { + format!("GRANT ALL PRIVILEGES ON {db_name}.* TO {db_name}@{host}") +} + +pub fn grant_restricted_privileges(db_name: &str, host: &str) -> String { format!("GRANT SELECT, INSERT, UPDATE, DELETE ON {db_name}.* TO {db_name}@{host}") } diff --git a/src/common/statement/postgres.rs b/src/common/statement/postgres.rs index cfd3b71..0af7f04 100644 --- a/src/common/statement/postgres.rs +++ b/src/common/statement/postgres.rs @@ -13,11 +13,15 @@ pub fn create_role(name: &str) -> String { format!("CREATE ROLE {name} WITH LOGIN PASSWORD '{name}'") } -pub fn grant_table_privileges(role_name: &str) -> String { +pub fn grant_database_ownership(db_name: &str, role_name: &str) -> String { + format!("ALTER DATABASE {db_name} OWNER to {role_name}") +} + +pub fn grant_restricted_table_privileges(role_name: &str) -> String { format!("GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA public TO {role_name}") } -pub fn grant_sequence_privileges(role_name: &str) -> String { +pub fn grant_restricted_sequence_privileges(role_name: &str) -> String { format!("GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO {role_name}") } diff --git a/src/sync/backend/mysql/diesel.rs b/src/sync/backend/mysql/diesel.rs index a96a866..bef4289 100644 --- a/src/sync/backend/mysql/diesel.rs +++ b/src/sync/backend/mysql/diesel.rs @@ -187,15 +187,23 @@ impl Backend for DieselMySQLBackend { MySQLBackendWrapper::new(self).init() } - fn create(&self, db_id: Uuid) -> Result, BackendError> { - MySQLBackendWrapper::new(self).create(db_id) + fn create( + &self, + db_id: Uuid, + restrict_privileges: bool, + ) -> Result, BackendError> { + MySQLBackendWrapper::new(self).create(db_id, restrict_privileges) } fn clean(&self, db_id: Uuid) -> Result<(), BackendError> { MySQLBackendWrapper::new(self).clean(db_id) } - fn drop(&self, db_id: Uuid) -> Result<(), BackendError> { + fn drop( + &self, + db_id: Uuid, + _is_restricted: bool, + ) -> Result<(), BackendError> { MySQLBackendWrapper::new(self).drop(db_id) } } @@ -216,7 +224,13 @@ mod tests { common::statement::mysql::tests::{ CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS, }, - sync::db_pool::DatabasePoolBuilder, + sync::{ + backend::mysql::r#trait::tests::{ + test_backend_creates_database_with_unrestricted_privileges, + test_pool_drops_created_unrestricted_database, + }, + db_pool::DatabasePoolBuilder, + }, tests::get_privileged_mysql_config, }; @@ -225,7 +239,7 @@ mod tests { lock_read, test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables, test_backend_creates_database_with_restricted_privileges, test_backend_drops_database, - test_backend_drops_previous_databases, test_pool_drops_created_databases, + test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases, test_pool_drops_previous_databases, }, DieselMySQLBackend, @@ -272,6 +286,12 @@ mod tests { test_backend_creates_database_with_restricted_privileges(&backend); } + #[test] + fn backend_creates_database_with_unrestricted_privileges() { + let backend = create_backend(true).drop_previous_databases(false); + test_backend_creates_database_with_unrestricted_privileges(&backend); + } + #[test] fn backend_cleans_database_with_tables() { let backend = create_backend(true).drop_previous_databases(false); @@ -285,9 +305,15 @@ mod tests { } #[test] - fn backend_drops_database() { + fn backend_drops_restricted_database() { let backend = create_backend(true).drop_previous_databases(false); - test_backend_drops_database(&backend); + test_backend_drops_database(&backend, true); + } + + #[test] + fn backend_drops_unrestricted_database() { + let backend = create_backend(true).drop_previous_databases(false); + test_backend_drops_database(&backend, false); } #[test] @@ -308,7 +334,9 @@ mod tests { let guard = lock_read(); let db_pool = backend.create_database_pool().unwrap(); - let conn_pools = (0..NUM_DBS).map(|_| db_pool.pull()).collect::>(); + let conn_pools = (0..NUM_DBS) + .map(|_| db_pool.pull_immutable()) + .collect::>(); // insert single row into each database conn_pools.iter().enumerate().for_each(|(i, conn_pool)| { @@ -341,21 +369,43 @@ mod tests { let guard = lock_read(); let db_pool = backend.create_database_pool().unwrap(); - let conn_pool = db_pool.pull(); + let conn_pool = db_pool.pull_immutable(); let conn = &mut conn_pool.get().unwrap(); - // restricted operations - { - // DDL statements must fail - for stmt in DDL_STATEMENTS { - assert!(sql_query(stmt).execute(conn).is_err()); - } + // DDL statements must fail + for stmt in DDL_STATEMENTS { + assert!(sql_query(stmt).execute(conn).is_err()); + } - // DML statements must succeed + // DML statements must succeed + for stmt in DML_STATEMENTS { + assert!(sql_query(stmt).execute(conn).is_ok()); + } + } + + #[test] + fn pool_provides_unrestricted_databases() { + let backend = create_backend(true).drop_previous_databases(false); + + let guard = lock_read(); + + let db_pool = backend.create_database_pool().unwrap(); + + // DML statements must succeed + { + let conn_pool = db_pool.create_mutable().unwrap(); + let conn = &mut conn_pool.get().unwrap(); for stmt in DML_STATEMENTS { assert!(sql_query(stmt).execute(conn).is_ok()); } } + + // DDL statements must succeed + for stmt in DDL_STATEMENTS { + let conn_pool = db_pool.create_mutable().unwrap(); + let conn = &mut conn_pool.get().unwrap(); + assert!(sql_query(stmt).execute(conn).is_ok()); + } } #[test] @@ -370,7 +420,9 @@ mod tests { // fetch connection pools the first time { - let conn_pools = (0..NUM_DBS).map(|_| db_pool.pull()).collect::>(); + let conn_pools = (0..NUM_DBS) + .map(|_| db_pool.pull_immutable()) + .collect::>(); // databases must be empty for conn_pool in &conn_pools { @@ -392,7 +444,9 @@ mod tests { // fetch same connection pools a second time { - let conn_pools = (0..NUM_DBS).map(|_| db_pool.pull()).collect::>(); + let conn_pools = (0..NUM_DBS) + .map(|_| db_pool.pull_immutable()) + .collect::>(); // databases must be empty for conn_pool in &conn_pools { @@ -403,8 +457,14 @@ mod tests { } #[test] - fn pool_drops_created_databases() { + fn pool_drops_created_restricted_databases() { + let backend = create_backend(false); + test_pool_drops_created_restricted_databases(backend); + } + + #[test] + fn pool_drops_created_unrestricted_databases() { let backend = create_backend(false); - test_pool_drops_created_databases(backend); + test_pool_drops_created_unrestricted_database(backend); } } diff --git a/src/sync/backend/mysql/mysql.rs b/src/sync/backend/mysql/mysql.rs index 22f9a8f..d579c4e 100644 --- a/src/sync/backend/mysql/mysql.rs +++ b/src/sync/backend/mysql/mysql.rs @@ -156,15 +156,19 @@ impl Backend for MySQLBackend { MySQLBackendWrapper::new(self).init() } - fn create(&self, db_id: Uuid) -> Result, BackendError> { - MySQLBackendWrapper::new(self).create(db_id) + fn create( + &self, + db_id: Uuid, + restrict_privileges: bool, + ) -> Result, BackendError> { + MySQLBackendWrapper::new(self).create(db_id, restrict_privileges) } fn clean(&self, db_id: Uuid) -> Result<(), BackendError> { MySQLBackendWrapper::new(self).clean(db_id) } - fn drop(&self, db_id: Uuid) -> Result<(), BackendError> { + fn drop(&self, db_id: Uuid, _is_restricted: bool) -> Result<(), BackendError> { MySQLBackendWrapper::new(self).drop(db_id) } } @@ -180,7 +184,13 @@ mod tests { common::statement::mysql::tests::{ CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS, }, - sync::DatabasePoolBuilderTrait, + sync::{ + backend::mysql::r#trait::tests::{ + test_backend_creates_database_with_unrestricted_privileges, + test_pool_drops_created_unrestricted_database, + }, + DatabasePoolBuilderTrait, + }, tests::get_privileged_mysql_config, }; @@ -189,7 +199,7 @@ mod tests { lock_read, test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables, test_backend_creates_database_with_restricted_privileges, test_backend_drops_database, - test_backend_drops_previous_databases, test_pool_drops_created_databases, + test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases, test_pool_drops_previous_databases, }, MySQLBackend, @@ -223,6 +233,12 @@ mod tests { test_backend_creates_database_with_restricted_privileges(&backend); } + #[test] + fn backend_creates_database_with_unrestricted_privileges() { + let backend = create_backend(true).drop_previous_databases(false); + test_backend_creates_database_with_unrestricted_privileges(&backend); + } + #[test] fn backend_cleans_database_with_tables() { let backend = create_backend(true).drop_previous_databases(false); @@ -236,9 +252,15 @@ mod tests { } #[test] - fn backend_drops_database() { + fn backend_drops_restricted_database() { + let backend = create_backend(true).drop_previous_databases(false); + test_backend_drops_database(&backend, true); + } + + #[test] + fn backend_drops_unrestricted_database() { let backend = create_backend(true).drop_previous_databases(false); - test_backend_drops_database(&backend); + test_backend_drops_database(&backend, false); } #[test] @@ -259,7 +281,9 @@ mod tests { let guard = lock_read(); let db_pool = backend.create_database_pool().unwrap(); - let conn_pools = (0..NUM_DBS).map(|_| db_pool.pull()).collect::>(); + let conn_pools = (0..NUM_DBS) + .map(|_| db_pool.pull_immutable()) + .collect::>(); // insert single row into each database conn_pools.iter().enumerate().for_each(|(i, conn_pool)| { @@ -290,21 +314,43 @@ mod tests { let guard = lock_read(); let db_pool = backend.create_database_pool().unwrap(); - let conn_pool = db_pool.pull(); + let conn_pool = db_pool.pull_immutable(); let conn = &mut conn_pool.get().unwrap(); - // restricted operations - { - // DDL statements must fail - for stmt in DDL_STATEMENTS { - assert!(conn.query_drop(stmt).is_err()); - } + // DDL statements must fail + for stmt in DDL_STATEMENTS { + assert!(conn.query_drop(stmt).is_err()); + } + + // DML statements must succeed + for stmt in DML_STATEMENTS { + assert!(conn.query_drop(stmt).is_ok()); + } + } - // DML statements must succeed + #[test] + fn pool_provides_unrestricted_databases() { + let backend = create_backend(true).drop_previous_databases(false); + + let guard = lock_read(); + + let db_pool = backend.create_database_pool().unwrap(); + + // DML statements must succeed + { + let conn_pool = db_pool.create_mutable().unwrap(); + let conn = &mut conn_pool.get().unwrap(); for stmt in DML_STATEMENTS { assert!(conn.query_drop(stmt).is_ok()); } } + + // DDL statements must succeed + for stmt in DDL_STATEMENTS { + let conn_pool = db_pool.create_mutable().unwrap(); + let conn = &mut conn_pool.get().unwrap(); + assert!(conn.query_drop(stmt).is_ok()); + } } #[test] @@ -319,7 +365,9 @@ mod tests { // fetch connection pools the first time { - let conn_pools = (0..NUM_DBS).map(|_| db_pool.pull()).collect::>(); + let conn_pools = (0..NUM_DBS) + .map(|_| db_pool.pull_immutable()) + .collect::>(); // databases must be empty for conn_pool in &conn_pools { @@ -347,7 +395,9 @@ mod tests { // fetch same connection pools a second time { - let conn_pools = (0..NUM_DBS).map(|_| db_pool.pull()).collect::>(); + let conn_pools = (0..NUM_DBS) + .map(|_| db_pool.pull_immutable()) + .collect::>(); // databases must be empty for conn_pool in &conn_pools { @@ -363,8 +413,14 @@ mod tests { } #[test] - fn pool_drops_created_databases() { + fn pool_drops_created_restricted_databases() { + let backend = create_backend(false); + test_pool_drops_created_restricted_databases(backend); + } + + #[test] + fn pool_drops_created_unrestricted_database() { let backend = create_backend(false); - test_pool_drops_created_databases(backend); + test_pool_drops_created_unrestricted_database(backend); } } diff --git a/src/sync/backend/mysql/trait.rs b/src/sync/backend/mysql/trait.rs index b0a8b7a..b2dbfde 100644 --- a/src/sync/backend/mysql/trait.rs +++ b/src/sync/backend/mysql/trait.rs @@ -91,6 +91,7 @@ impl<'a, B: MySQLBackend> MySQLBackendWrapper<'a, B> { pub(super) fn create( &self, db_id: uuid::Uuid, + restrict_privileges: bool, ) -> Result, BackendError> { // Get database name based on UUID let db_name = crate::util::get_db_name(db_id); @@ -105,7 +106,7 @@ impl<'a, B: MySQLBackend> MySQLBackendWrapper<'a, B> { self.execute(mysql::create_database(db_name).as_str(), conn) .map_err(Into::into)?; - // Create CRUD user + // Create user self.execute(mysql::create_user(db_name, host).as_str(), conn) .map_err(Into::into)?; @@ -116,12 +117,22 @@ impl<'a, B: MySQLBackend> MySQLBackendWrapper<'a, B> { self.execute(mysql::USE_DEFAULT_DATABASE, conn) .map_err(Into::into)?; - // Grant privileges to CRUD role - self.execute(mysql::grant_privileges(db_name, host).as_str(), conn) + if restrict_privileges { + // Grant privileges to restricted user + self.execute( + mysql::grant_restricted_privileges(db_name, host).as_str(), + conn, + ) .map_err(Into::into)?; + } else { + // Grant all privileges to database-unrestricted user + self.execute(mysql::grant_all_privileges(db_name, host).as_str(), conn) + .map_err(Into::into)?; + } - // Create connection pool with CRUD role + // Create connection pool with attached user let pool = self.create_connection_pool(db_id)?; + Ok(pool) } @@ -326,7 +337,7 @@ pub(super) mod tests { // database must exist after creating through backend backend.init().unwrap(); - backend.create(db_id).unwrap(); + backend.create(db_id, true).unwrap(); assert!(database_exists(db_name, conn)); } @@ -347,6 +358,52 @@ pub(super) mod tests { } } + pub fn test_backend_creates_database_with_unrestricted_privileges(backend: &impl Backend) { + let guard = lock_read(); + + { + let db_id = Uuid::new_v4(); + let db_name = get_db_name(db_id); + let db_name = db_name.as_str(); + + // privileged operations + { + let conn_pool = get_privileged_connection_pool(); + let conn = &mut conn_pool.get().unwrap(); + + // database must not exist + assert!(!database_exists(db_name, conn)); + + // database must exist after creating through backend + backend.init().unwrap(); + backend.create(db_id, false).unwrap(); + assert!(database_exists(db_name, conn)); + } + + // DML statements must succeed + { + let conn_pool = create_restricted_connection_pool(db_name); + let conn = &mut conn_pool.get().unwrap(); + for stmt in DML_STATEMENTS { + assert!(sql_query(stmt).execute(conn).is_ok()); + } + } + } + + // DDL statements must succeed + for stmt in DDL_STATEMENTS { + let db_id = Uuid::new_v4(); + let db_name = get_db_name(db_id); + let db_name = db_name.as_str(); + + backend.create(db_id, false).unwrap(); + let conn_pool = create_restricted_connection_pool(db_name); + let conn = &mut conn_pool.get().unwrap(); + + assert!(sql_query(stmt).execute(conn).is_ok()); + } + } + pub fn test_backend_cleans_database_with_tables(backend: &impl Backend) { const NUM_BOOKS: i64 = 3; @@ -357,7 +414,7 @@ pub(super) mod tests { let guard = lock_read(); backend.init().unwrap(); - backend.create(db_id).unwrap(); + backend.create(db_id, true).unwrap(); table! { book (id) { @@ -403,11 +460,11 @@ pub(super) mod tests { let guard = lock_read(); backend.init().unwrap(); - backend.create(db_id).unwrap(); + backend.create(db_id, true).unwrap(); backend.clean(db_id).unwrap(); } - pub fn test_backend_drops_database(backend: &impl Backend) { + pub fn test_backend_drops_database(backend: &impl Backend, restricted: bool) { let db_id = Uuid::new_v4(); let db_name = get_db_name(db_id); let db_name = db_name.as_str(); @@ -419,11 +476,11 @@ pub(super) mod tests { // database must exist backend.init().unwrap(); - backend.create(db_id).unwrap(); + backend.create(db_id, restricted).unwrap(); assert!(database_exists(db_name, conn)); // database must not exist - backend.drop(db_id).unwrap(); + backend.drop(db_id, restricted).unwrap(); assert!(!database_exists(db_name, conn)); } @@ -446,7 +503,7 @@ pub(super) mod tests { } } - pub fn test_pool_drops_created_databases(backend: impl Backend) { + pub fn test_pool_drops_created_restricted_databases(backend: impl Backend) { const NUM_DBS: i64 = 3; let conn_pool = get_privileged_connection_pool(); @@ -460,7 +517,9 @@ pub(super) mod tests { assert_eq!(count_all_databases(conn), 0); // fetch connection pools - let conn_pools = (0..NUM_DBS).map(|_| db_pool.pull()).collect::>(); + let conn_pools = (0..NUM_DBS) + .map(|_| db_pool.pull_immutable()) + .collect::>(); // there must be databases assert_eq!(count_all_databases(conn), NUM_DBS); @@ -477,4 +536,33 @@ pub(super) mod tests { // there must be no databases assert_eq!(count_all_databases(conn), 0); } + + pub fn test_pool_drops_created_unrestricted_database(backend: impl Backend) { + let conn_pool = get_privileged_connection_pool(); + let conn = &mut conn_pool.get().unwrap(); + + let guard = lock_drop(); + + let db_pool = backend.create_database_pool().unwrap(); + + // there must be no databases + assert_eq!(count_all_databases(conn), 0); + + // fetch connection pool + let conn_pool = db_pool.create_mutable().unwrap(); + + // there must be a database + assert_eq!(count_all_databases(conn), 1); + + // must drop database + drop(conn_pool); + + // there must be no databases + assert_eq!(count_all_databases(conn), 0); + + drop(db_pool); + + // there must be no databases + assert_eq!(count_all_databases(conn), 0); + } } diff --git a/src/sync/backend/postgres/diesel.rs b/src/sync/backend/postgres/diesel.rs index 6cb83ee..56d697c 100644 --- a/src/sync/backend/postgres/diesel.rs +++ b/src/sync/backend/postgres/diesel.rs @@ -108,7 +108,10 @@ impl PostgresBackend for DieselPostgresBackend { self.default_pool.get() } - fn establish_database_connection(&self, db_id: Uuid) -> ConnectionResult { + fn establish_privileged_database_connection( + &self, + db_id: Uuid, + ) -> ConnectionResult { let db_name = get_db_name(db_id); let database_url = self .privileged_config @@ -116,6 +119,20 @@ impl PostgresBackend for DieselPostgresBackend { PgConnection::establish(database_url.as_str()) } + fn establish_restricted_database_connection( + &self, + db_id: Uuid, + ) -> ConnectionResult { + let db_name = get_db_name(db_id); + let db_name = db_name.as_str(); + let database_url = self.privileged_config.restricted_database_connection_url( + db_name, + Some(db_name), + db_name, + ); + PgConnection::establish(database_url.as_str()) + } + fn put_database_connection(&self, db_id: Uuid, conn: PgConnection) { self.db_conns.lock().insert(db_id, conn); } @@ -189,16 +206,24 @@ impl Backend for DieselPostgresBackend { PostgresBackendWrapper::new(self).init() } - fn create(&self, db_id: Uuid) -> Result, BackendError> { - PostgresBackendWrapper::new(self).create(db_id) + fn create( + &self, + db_id: Uuid, + restrict_privileges: bool, + ) -> Result, BackendError> { + PostgresBackendWrapper::new(self).create(db_id, restrict_privileges) } fn clean(&self, db_id: Uuid) -> Result<(), BackendError> { PostgresBackendWrapper::new(self).clean(db_id) } - fn drop(&self, db_id: Uuid) -> Result<(), BackendError> { - PostgresBackendWrapper::new(self).drop(db_id) + fn drop( + &self, + db_id: Uuid, + is_restricted: bool, + ) -> Result<(), BackendError> { + PostgresBackendWrapper::new(self).drop(db_id, is_restricted) } } @@ -222,7 +247,10 @@ mod tests { CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS, }, }, - sync::db_pool::DatabasePoolBuilder, + sync::{ + backend::postgres::r#trait::tests::test_backend_creates_database_with_unrestricted_privileges, + db_pool::DatabasePoolBuilder, + }, }; use super::{ @@ -230,8 +258,8 @@ mod tests { lock_read, test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables, test_backend_creates_database_with_restricted_privileges, test_backend_drops_database, - test_backend_drops_previous_databases, test_pool_drops_created_databases, - test_pool_drops_previous_databases, + test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases, + test_pool_drops_created_unrestricted_database, test_pool_drops_previous_databases, }, DieselPostgresBackend, }; @@ -280,6 +308,12 @@ mod tests { test_backend_creates_database_with_restricted_privileges(&backend); } + #[test] + fn backend_creates_database_with_unrestricted_privileges() { + let backend = create_backend(true).drop_previous_databases(false); + test_backend_creates_database_with_unrestricted_privileges(&backend); + } + #[test] fn backend_cleans_database_with_tables() { let backend = create_backend(true).drop_previous_databases(false); @@ -293,9 +327,15 @@ mod tests { } #[test] - fn backend_drops_database() { + fn backend_drops_restricted_database() { let backend = create_backend(true).drop_previous_databases(false); - test_backend_drops_database(&backend); + test_backend_drops_database(&backend, true); + } + + #[test] + fn backend_drops_unrestricted_database() { + let backend = create_backend(true).drop_previous_databases(false); + test_backend_drops_database(&backend, false); } #[test] @@ -316,7 +356,9 @@ mod tests { let guard = lock_read(); let db_pool = backend.create_database_pool().unwrap(); - let conn_pools = (0..NUM_DBS).map(|_| db_pool.pull()).collect::>(); + let conn_pools = (0..NUM_DBS) + .map(|_| db_pool.pull_immutable()) + .collect::>(); // insert single row into each database conn_pools.iter().enumerate().for_each(|(i, conn_pool)| { @@ -349,21 +391,43 @@ mod tests { let guard = lock_read(); let db_pool = backend.create_database_pool().unwrap(); - let conn_pool = db_pool.pull(); + let conn_pool = db_pool.pull_immutable(); let conn = &mut conn_pool.get().unwrap(); - // restricted operations - { - // DDL statements must fail - for stmt in DDL_STATEMENTS { - assert!(sql_query(stmt).execute(conn).is_err()); - } + // DDL statements must fail + for stmt in DDL_STATEMENTS { + assert!(sql_query(stmt).execute(conn).is_err()); + } + + // DML statements must succeed + for stmt in DML_STATEMENTS { + assert!(sql_query(stmt).execute(conn).is_ok()); + } + } + + #[test] + fn pool_provides_unrestricted_databases() { + let backend = create_backend(true).drop_previous_databases(false); + + let guard = lock_read(); + + let db_pool = backend.create_database_pool().unwrap(); - // DML statements must succeed + // DML statements must succeed + { + let conn_pool = db_pool.create_mutable().unwrap(); + let conn = &mut conn_pool.get().unwrap(); for stmt in DML_STATEMENTS { assert!(sql_query(stmt).execute(conn).is_ok()); } } + + // DDL statements must succeed + for stmt in DDL_STATEMENTS { + let conn_pool = db_pool.create_mutable().unwrap(); + let conn = &mut conn_pool.get().unwrap(); + assert!(sql_query(stmt).execute(conn).is_ok()); + } } #[test] @@ -378,7 +442,9 @@ mod tests { // fetch connection pools the first time { - let conn_pools = (0..NUM_DBS).map(|_| db_pool.pull()).collect::>(); + let conn_pools = (0..NUM_DBS) + .map(|_| db_pool.pull_immutable()) + .collect::>(); // databases must be empty for conn_pool in &conn_pools { @@ -400,7 +466,9 @@ mod tests { // fetch same connection pools a second time { - let conn_pools = (0..NUM_DBS).map(|_| db_pool.pull()).collect::>(); + let conn_pools = (0..NUM_DBS) + .map(|_| db_pool.pull_immutable()) + .collect::>(); // databases must be empty for conn_pool in &conn_pools { @@ -411,8 +479,14 @@ mod tests { } #[test] - fn pool_drops_created_databases() { + fn pool_drops_created_restricted_databases() { + let backend = create_backend(false); + test_pool_drops_created_restricted_databases(backend); + } + + #[test] + fn pool_drops_created_unrestricted_database() { let backend = create_backend(false); - test_pool_drops_created_databases(backend); + test_pool_drops_created_unrestricted_database(backend); } } diff --git a/src/sync/backend/postgres/postgres.rs b/src/sync/backend/postgres/postgres.rs index 3a5154a..3fea71c 100644 --- a/src/sync/backend/postgres/postgres.rs +++ b/src/sync/backend/postgres/postgres.rs @@ -106,13 +106,27 @@ impl PostgresBackendTrait for PostgresBackend { self.default_pool.get() } - fn establish_database_connection(&self, db_id: Uuid) -> Result { + fn establish_privileged_database_connection( + &self, + db_id: Uuid, + ) -> Result { let mut config = self.config.clone(); let db_name = get_db_name(db_id); config.dbname(db_name.as_str()); config.connect(NoTls).map_err(Into::into) } + fn establish_restricted_database_connection( + &self, + db_id: Uuid, + ) -> Result { + let mut config = self.config.clone(); + let db_name = get_db_name(db_id); + let db_name = db_name.as_str(); + config.user(db_name).password(db_name).dbname(db_name); + config.connect(NoTls).map_err(Into::into) + } + fn put_database_connection(&self, db_id: Uuid, conn: Client) { self.db_conns.lock().insert(db_id, conn); } @@ -214,16 +228,21 @@ impl Backend for PostgresBackend { fn create( &self, db_id: Uuid, + restrict_privileges: bool, ) -> Result, BackendError> { - PostgresBackendWrapper::new(self).create(db_id) + PostgresBackendWrapper::new(self).create(db_id, restrict_privileges) } fn clean(&self, db_id: Uuid) -> Result<(), BackendError> { PostgresBackendWrapper::new(self).clean(db_id) } - fn drop(&self, db_id: Uuid) -> Result<(), BackendError> { - PostgresBackendWrapper::new(self).drop(db_id) + fn drop( + &self, + db_id: Uuid, + is_restricted: bool, + ) -> Result<(), BackendError> { + PostgresBackendWrapper::new(self).drop(db_id, is_restricted) } } @@ -238,7 +257,13 @@ mod tests { common::statement::postgres::tests::{ CREATE_ENTITIES_STATEMENTS, DDL_STATEMENTS, DML_STATEMENTS, }, - sync::db_pool::DatabasePoolBuilder, + sync::{ + backend::postgres::r#trait::tests::{ + test_backend_creates_database_with_unrestricted_privileges, + test_pool_drops_created_unrestricted_database, + }, + db_pool::DatabasePoolBuilder, + }, PrivilegedPostgresConfig, }; @@ -247,7 +272,7 @@ mod tests { lock_read, test_backend_cleans_database_with_tables, test_backend_cleans_database_without_tables, test_backend_creates_database_with_restricted_privileges, test_backend_drops_database, - test_backend_drops_previous_databases, test_pool_drops_created_databases, + test_backend_drops_previous_databases, test_pool_drops_created_restricted_databases, test_pool_drops_previous_databases, }, PostgresBackend, @@ -284,6 +309,12 @@ mod tests { test_backend_creates_database_with_restricted_privileges(&backend); } + #[test] + fn backend_creates_database_with_unrestricted_privileges() { + let backend = create_backend(true).drop_previous_databases(false); + test_backend_creates_database_with_unrestricted_privileges(&backend); + } + #[test] fn backend_cleans_database_with_tables() { let backend = create_backend(true).drop_previous_databases(false); @@ -297,9 +328,15 @@ mod tests { } #[test] - fn backend_drops_database() { + fn backend_drops_restricted_database() { let backend = create_backend(true).drop_previous_databases(false); - test_backend_drops_database(&backend); + test_backend_drops_database(&backend, true); + } + + #[test] + fn backend_drops_unrestricted_database() { + let backend = create_backend(true).drop_previous_databases(false); + test_backend_drops_database(&backend, false); } #[test] @@ -320,7 +357,9 @@ mod tests { let guard = lock_read(); let db_pool = backend.create_database_pool().unwrap(); - let conn_pools = (0..NUM_DBS).map(|_| db_pool.pull()).collect::>(); + let conn_pools = (0..NUM_DBS) + .map(|_| db_pool.pull_immutable()) + .collect::>(); // insert single row into each database conn_pools.iter().enumerate().for_each(|(i, conn_pool)| { @@ -354,7 +393,7 @@ mod tests { let db_pool = backend.create_database_pool().unwrap(); - let conn_pool = db_pool.pull(); + let conn_pool = db_pool.pull_immutable(); let conn = &mut conn_pool.get().unwrap(); // DDL statements must fail @@ -368,6 +407,31 @@ mod tests { } } + #[test] + fn pool_provides_unrestricted_databases() { + let backend = create_backend(true).drop_previous_databases(false); + + let guard = lock_read(); + + let db_pool = backend.create_database_pool().unwrap(); + + // DML statements must succeed + { + let conn_pool = db_pool.create_mutable().unwrap(); + let conn = &mut conn_pool.get().unwrap(); + for stmt in DML_STATEMENTS { + assert!(conn.execute(stmt, &[]).is_ok()); + } + } + + // DDL statements must succeed + for stmt in DDL_STATEMENTS { + let conn_pool = db_pool.create_mutable().unwrap(); + let conn = &mut conn_pool.get().unwrap(); + assert!(conn.execute(stmt, &[]).is_ok()); + } + } + #[test] fn pool_provides_clean_databases() { const NUM_DBS: i64 = 3; @@ -380,7 +444,9 @@ mod tests { // fetch connection pools the first time { - let conn_pools = (0..NUM_DBS).map(|_| db_pool.pull()).collect::>(); + let conn_pools = (0..NUM_DBS) + .map(|_| db_pool.pull_immutable()) + .collect::>(); // databases must be empty for conn_pool in &conn_pools { @@ -403,7 +469,9 @@ mod tests { // fetch same connection pools a second time { - let conn_pools = (0..NUM_DBS).map(|_| db_pool.pull()).collect::>(); + let conn_pools = (0..NUM_DBS) + .map(|_| db_pool.pull_immutable()) + .collect::>(); // databases must be empty for conn_pool in &conn_pools { @@ -419,8 +487,14 @@ mod tests { } #[test] - fn pool_drops_created_databases() { + fn pool_drops_created_restricted_databases() { + let backend = create_backend(false); + test_pool_drops_created_restricted_databases(backend); + } + + #[test] + fn pool_drops_created_unrestricted_database() { let backend = create_backend(false); - test_pool_drops_created_databases(backend); + test_pool_drops_created_unrestricted_database(backend); } } diff --git a/src/sync/backend/postgres/trait.rs b/src/sync/backend/postgres/trait.rs index a7d3d86..58f86a7 100644 --- a/src/sync/backend/postgres/trait.rs +++ b/src/sync/backend/postgres/trait.rs @@ -26,7 +26,11 @@ pub(super) trait PostgresBackend { fn get_default_connection( &self, ) -> Result, r2d2::Error>; - fn establish_database_connection( + fn establish_privileged_database_connection( + &self, + db_id: Uuid, + ) -> Result<::Connection, Self::ConnectionError>; + fn establish_restricted_database_connection( &self, db_id: Uuid, ) -> Result<::Connection, Self::ConnectionError>; @@ -98,6 +102,7 @@ impl<'a, B: PostgresBackend> PostgresBackendWrapper<'a, B> { pub(super) fn create( &self, db_id: uuid::Uuid, + restrict_privileges: bool, ) -> Result, BackendError> { // Get database name based on UUID let db_name = crate::util::get_db_name(db_id); @@ -111,7 +116,7 @@ impl<'a, B: PostgresBackend> PostgresBackendWrapper<'a, B> { self.execute_query(postgres::create_database(db_name).as_str(), conn) .map_err(Into::into)?; - // Create CRUD role + // Create role self.execute_query(postgres::create_role(db_name).as_str(), conn) .map_err(Into::into)?; } @@ -119,30 +124,50 @@ impl<'a, B: PostgresBackend> PostgresBackendWrapper<'a, B> { { // Connect to database as privileged user let mut conn = self - .establish_database_connection(db_id) + .establish_privileged_database_connection(db_id) .map_err(Into::into)?; - // Create entities - self.create_entities(&mut conn); + if restrict_privileges { + // Create entities as privileged user + self.create_entities(&mut conn); - // Grant privileges to CRUD role - self.execute_query( - postgres::grant_table_privileges(db_name).as_str(), - &mut conn, - ) - .map_err(Into::into)?; - self.execute_query( - postgres::grant_sequence_privileges(db_name).as_str(), - &mut conn, - ) - .map_err(Into::into)?; + // Grant table privileges to restricted role + self.execute_query( + postgres::grant_restricted_table_privileges(db_name).as_str(), + &mut conn, + ) + .map_err(Into::into)?; + + // Grant sequence privileges to restricted role + self.execute_query( + postgres::grant_restricted_sequence_privileges(db_name).as_str(), + &mut conn, + ) + .map_err(Into::into)?; + + // Store database connection for reuse when cleaning + self.put_database_connection(db_id, conn); + } else { + // Grant database ownership to database-unrestricted role + self.execute_query( + postgres::grant_database_ownership(db_name, db_name).as_str(), + &mut conn, + ) + .map_err(Into::into)?; - // Store database connection for reuse when cleaning - self.put_database_connection(db_id, conn); + // Connect to database as database-unrestricted user + let mut conn = self + .establish_restricted_database_connection(db_id) + .map_err(Into::into)?; + + // Create entities as database-unrestricted user + self.create_entities(&mut conn); + } } - // Create connection pool with CRUD role + // Create connection pool with attached role let pool = self.create_connection_pool(db_id)?; + Ok(pool) } @@ -174,9 +199,10 @@ impl<'a, B: PostgresBackend> PostgresBackendWrapper<'a, B> { pub(super) fn drop( &self, db_id: uuid::Uuid, + is_restricted: bool, ) -> Result<(), BackendError> { // Drop privileged connection to database - { + if is_restricted { self.get_database_connection(db_id); } @@ -191,7 +217,7 @@ impl<'a, B: PostgresBackend> PostgresBackendWrapper<'a, B> { self.execute_query(postgres::drop_database(db_name).as_str(), conn) .map_err(Into::into)?; - // Drop CRUD role + // Drop attached role self.execute_query(postgres::drop_role(db_name).as_str(), conn) .map_err(Into::into)?; @@ -328,7 +354,7 @@ pub(super) mod tests { // database must exist after creating through backend backend.init().unwrap(); - backend.create(db_id).unwrap(); + backend.create(db_id, true).unwrap(); assert!(database_exists(db_name, conn)); } @@ -349,6 +375,52 @@ pub(super) mod tests { } } + pub fn test_backend_creates_database_with_unrestricted_privileges(backend: &impl Backend) { + let guard = lock_read(); + + { + let db_id = Uuid::new_v4(); + let db_name = get_db_name(db_id); + let db_name = db_name.as_str(); + + // privileged operations + { + let conn_pool = get_privileged_connection_pool(); + let conn = &mut conn_pool.get().unwrap(); + + // database must not exist + assert!(!database_exists(db_name, conn)); + + // database must exist after creating through backend + backend.init().unwrap(); + backend.create(db_id, false).unwrap(); + assert!(database_exists(db_name, conn)); + } + + // DML statements must succeed + { + let conn_pool = &mut create_restricted_connection_pool(db_name); + let conn = &mut conn_pool.get().unwrap(); + for stmt in DML_STATEMENTS { + assert!(sql_query(stmt).execute(conn).is_ok()); + } + } + } + + // DDL statements must succeed + for stmt in DDL_STATEMENTS { + let db_id = Uuid::new_v4(); + let db_name = get_db_name(db_id); + let db_name = db_name.as_str(); + + backend.create(db_id, false).unwrap(); + let conn_pool = &mut create_restricted_connection_pool(db_name); + let conn = &mut conn_pool.get().unwrap(); + + assert!(sql_query(stmt).execute(conn).is_ok()); + } + } + pub fn test_backend_cleans_database_with_tables(backend: &impl Backend) { const NUM_BOOKS: i64 = 3; @@ -359,7 +431,7 @@ pub(super) mod tests { let guard = lock_read(); backend.init().unwrap(); - backend.create(db_id).unwrap(); + backend.create(db_id, true).unwrap(); let conn_pool = &mut create_restricted_connection_pool(db_name); let conn = &mut conn_pool.get().unwrap(); @@ -405,11 +477,11 @@ pub(super) mod tests { let guard = lock_read(); backend.init().unwrap(); - backend.create(db_id).unwrap(); + backend.create(db_id, true).unwrap(); backend.clean(db_id).unwrap(); } - pub fn test_backend_drops_database(backend: &impl Backend) { + pub fn test_backend_drops_database(backend: &impl Backend, restricted: bool) { let db_id = Uuid::new_v4(); let db_name = get_db_name(db_id); let db_name = db_name.as_str(); @@ -421,11 +493,11 @@ pub(super) mod tests { // database must exist backend.init().unwrap(); - backend.create(db_id).unwrap(); + backend.create(db_id, restricted).unwrap(); assert!(database_exists(db_name, conn)); // database must not exist - backend.drop(db_id).unwrap(); + backend.drop(db_id, restricted).unwrap(); assert!(!database_exists(db_name, conn)); } @@ -448,35 +520,66 @@ pub(super) mod tests { } } - pub fn test_pool_drops_created_databases(backend: impl Backend) { + pub fn test_pool_drops_created_restricted_databases(backend: impl Backend) { const NUM_DBS: i64 = 3; - let privileged_conn_pool = get_privileged_connection_pool(); - let privileged_conn = &mut privileged_conn_pool.get().unwrap(); + let conn_pool = get_privileged_connection_pool(); + let conn = &mut conn_pool.get().unwrap(); let guard = lock_drop(); let db_pool = backend.create_database_pool().unwrap(); // there must be no databases - assert_eq!(count_all_databases(privileged_conn), 0); + assert_eq!(count_all_databases(conn), 0); // fetch connection pools - let conn_pools = (0..NUM_DBS).map(|_| db_pool.pull()).collect::>(); + let conn_pools = (0..NUM_DBS) + .map(|_| db_pool.pull_immutable()) + .collect::>(); // there must be databases - assert_eq!(count_all_databases(privileged_conn), NUM_DBS); + assert_eq!(count_all_databases(conn), NUM_DBS); // must release databases back to pool drop(conn_pools); // there must be databases - assert_eq!(count_all_databases(privileged_conn), NUM_DBS); + assert_eq!(count_all_databases(conn), NUM_DBS); // must drop databases drop(db_pool); // there must be no databases - assert_eq!(count_all_databases(privileged_conn), 0); + assert_eq!(count_all_databases(conn), 0); + } + + pub fn test_pool_drops_created_unrestricted_database(backend: impl Backend) { + let conn_pool = get_privileged_connection_pool(); + let conn = &mut conn_pool.get().unwrap(); + + let guard = lock_drop(); + + let db_pool = backend.create_database_pool().unwrap(); + + // there must be no databases + assert_eq!(count_all_databases(conn), 0); + + // fetch connection pool + let conn_pool = db_pool.create_mutable().unwrap(); + + // there must be a database + assert_eq!(count_all_databases(conn), 1); + + // must drop database + drop(conn_pool); + + // there must be no databases + assert_eq!(count_all_databases(conn), 0); + + drop(db_pool); + + // there must be no databases + assert_eq!(count_all_databases(conn), 0); } } diff --git a/src/sync/backend/trait.rs b/src/sync/backend/trait.rs index 48efcc7..a29a113 100644 --- a/src/sync/backend/trait.rs +++ b/src/sync/backend/trait.rs @@ -22,11 +22,16 @@ pub trait Backend: Sized + Send + Sync + 'static { fn create( &self, db_id: Uuid, + restrict_privileges: bool, ) -> Result, Error>; /// Cleans a database fn clean(&self, db_id: Uuid) -> Result<(), Error>; /// Drops a database - fn drop(&self, db_id: Uuid) -> Result<(), Error>; + fn drop( + &self, + db_id: Uuid, + is_restricted: bool, + ) -> Result<(), Error>; } diff --git a/src/sync/conn_pool.rs b/src/sync/conn_pool.rs index 7d3a368..05c5326 100644 --- a/src/sync/conn_pool.rs +++ b/src/sync/conn_pool.rs @@ -5,45 +5,84 @@ use uuid::Uuid; use super::backend::{r#trait::Backend, Error as BackendError}; -/// Connection pool wrapper -pub struct ConnectionPool { +struct ConnectionPool { backend: Arc, db_id: Uuid, conn_pool: Option>, + is_restricted: bool, } -impl ConnectionPool { +impl Deref for ConnectionPool { + type Target = Pool; + + fn deref(&self) -> &Self::Target { + self.conn_pool + .as_ref() + .expect("conn_pool must always contain a [Some] value") + } +} + +impl Drop for ConnectionPool { + fn drop(&mut self) { + self.conn_pool = None; + (*self.backend).drop(self.db_id, self.is_restricted).ok(); + } +} + +/// Reusable connection pool wrapper +pub struct ReusableConnectionPool(ConnectionPool); + +impl ReusableConnectionPool { pub(crate) fn new( backend: Arc, ) -> Result> { let db_id = Uuid::new_v4(); - let conn_pool = backend.create(db_id)?; + let conn_pool = backend.create(db_id, true)?; - Ok(Self { + Ok(Self(ConnectionPool { backend, db_id, conn_pool: Some(conn_pool), - }) + is_restricted: true, + })) } pub(crate) fn clean(&mut self) -> Result<(), BackendError> { - self.backend.clean(self.db_id) + self.0.backend.clean(self.0.db_id) } } -impl Deref for ConnectionPool { +impl Deref for ReusableConnectionPool { type Target = Pool; fn deref(&self) -> &Self::Target { - self.conn_pool - .as_ref() - .expect("conn_pool must always contain a [Some] value") + &self.0 } } -impl Drop for ConnectionPool { - fn drop(&mut self) { - self.conn_pool = None; - (*self.backend).drop(self.db_id).ok(); +/// Single-use connection pool wrapper +pub struct SingleUseConnectionPool(ConnectionPool); + +impl SingleUseConnectionPool { + pub(crate) fn new( + backend: Arc, + ) -> Result> { + let db_id = Uuid::new_v4(); + let conn_pool = backend.create(db_id, false)?; + + Ok(Self(ConnectionPool { + backend, + db_id, + conn_pool: Some(conn_pool), + is_restricted: false, + })) + } +} + +impl Deref for SingleUseConnectionPool { + type Target = Pool; + + fn deref(&self) -> &Self::Target { + &self.0 } } diff --git a/src/sync/db_pool.rs b/src/sync/db_pool.rs index 72fc15a..86375b1 100644 --- a/src/sync/db_pool.rs +++ b/src/sync/db_pool.rs @@ -2,15 +2,23 @@ use std::sync::Arc; use super::{ backend::{r#trait::Backend, Error}, - conn_pool::ConnectionPool, + conn_pool::{ReusableConnectionPool as ReusableConnectionPoolInner, SingleUseConnectionPool}, object_pool::{ObjectPool, Reusable}, }; +/// Wrapper for a reusable connection pool wrapped in a reusable object wrapper +pub type ReusableConnectionPool<'a, B> = Reusable<'a, ReusableConnectionPoolInner>; + /// Database pool -pub struct DatabasePool(ObjectPool>); +pub struct DatabasePool { + backend: Arc, + object_pool: ObjectPool>, +} impl DatabasePool { /// Pulls a reusable connection pool + /// + /// Privileges are granted only for ``SELECT``, ``INSERT``, ``UPDATE``, and ``DELETE`` operations. /// # Example /// ``` /// use db_pool::{ @@ -38,11 +46,49 @@ impl DatabasePool { /// .unwrap(); /// /// let db_pool = backend.create_database_pool().unwrap(); - /// let conn_pool = db_pool.pull(); + /// let conn_pool = db_pool.pull_immutable(); /// ``` #[must_use] - pub fn pull(&self) -> Reusable> { - self.0.pull() + pub fn pull_immutable(&self) -> Reusable> { + self.object_pool.pull() + } + + /// Creates a single-use connection pool + /// + /// All privileges are granted. + /// # Example + /// ``` + /// use db_pool::{ + /// sync::{DatabasePoolBuilderTrait, DieselPostgresBackend}, + /// PrivilegedPostgresConfig, + /// }; + /// use diesel::{sql_query, RunQueryDsl}; + /// use dotenvy::dotenv; + /// use r2d2::Pool; + /// + /// dotenv().ok(); + /// + /// let config = PrivilegedPostgresConfig::from_env().unwrap(); + /// + /// let backend = DieselPostgresBackend::new( + /// config, + /// || Pool::builder().max_size(10), + /// || Pool::builder().max_size(2), + /// move |conn| { + /// sql_query("CREATE TABLE book(id SERIAL PRIMARY KEY, title TEXT NOT NULL)") + /// .execute(conn) + /// .unwrap(); + /// }, + /// ) + /// .unwrap(); + /// + /// let db_pool = backend.create_database_pool().unwrap(); + /// let conn_pool = db_pool.create_mutable(); + /// ``` + pub fn create_mutable( + &self, + ) -> Result, Error> { + SingleUseConnectionPool::new(self.backend.clone()) } } @@ -82,18 +128,25 @@ pub trait DatabasePoolBuilder: Backend { ) -> Result, Error> { self.init()?; let backend = Arc::new(self); - let object_pool = ObjectPool::new( - move || { - let backend = backend.clone(); - ConnectionPool::new(backend).expect("connection pool creation must succeed") - }, - |conn_pool| { - conn_pool - .clean() - .expect("connection pool cleaning must succeed"); - }, - ); - Ok(DatabasePool(object_pool)) + let object_pool = { + let backend = backend.clone(); + ObjectPool::new( + move || { + let backend = backend.clone(); + ReusableConnectionPoolInner::new(backend) + .expect("connection pool creation must succeed") + }, + |conn_pool| { + conn_pool + .clean() + .expect("connection pool cleaning must succeed"); + }, + ) + }; + Ok(DatabasePool { + backend, + object_pool, + }) } } diff --git a/src/sync/mod.rs b/src/sync/mod.rs index 7f1e8b0..0c4c900 100644 --- a/src/sync/mod.rs +++ b/src/sync/mod.rs @@ -5,7 +5,9 @@ mod object_pool; mod wrapper; pub use backend::*; -pub use conn_pool::ConnectionPool; -pub use db_pool::{DatabasePool, DatabasePoolBuilder as DatabasePoolBuilderTrait}; -pub use object_pool::{ObjectPool, Reusable}; +pub use conn_pool::SingleUseConnectionPool; +pub use db_pool::{ + DatabasePool, DatabasePoolBuilder as DatabasePoolBuilderTrait, ReusableConnectionPool, +}; +pub use object_pool::ObjectPool; pub use wrapper::PoolWrapper; diff --git a/src/sync/wrapper.rs b/src/sync/wrapper.rs index ed55882..65e4e83 100644 --- a/src/sync/wrapper.rs +++ b/src/sync/wrapper.rs @@ -2,14 +2,14 @@ use std::ops::Deref; use r2d2::Pool; -use super::{backend::r#trait::Backend, conn_pool::ConnectionPool, object_pool::Reusable}; +use super::{backend::r#trait::Backend, db_pool::ReusableConnectionPool}; /// Connection pool wrapper to facilitate the use of pools in code under test and reusable pools in tests pub enum PoolWrapper { /// Connection pool used in code under test Pool(Pool), /// Reusable connection pool used in tests - ReusablePool(Reusable<'static, ConnectionPool>), + ReusablePool(ReusableConnectionPool<'static, B>), } impl Deref for PoolWrapper {