Skip to content

Commit

Permalink
Execute ConnectionBuilder options outside of the apply_migrations f…
Browse files Browse the repository at this point in the history
…unction
  • Loading branch information
sfauvel committed Nov 20, 2024
1 parent ba0f4de commit 75ce762
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 41 deletions.
80 changes: 55 additions & 25 deletions internal/mithril-persistence/src/sqlite/connection_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,6 @@ impl ConnectionBuilder {
)
})?;

let migrations = self.sql_migrations.clone();
self.apply_migrations(&connection, migrations)?;

Ok(connection)
}

/// Apply a list of migration to the connection.
pub fn apply_migrations(
self,
connection: &ConnectionThreadSafe,
sql_migrations: Vec<SqlMigration>,
) -> StdResult<()> {
let logger = self.base_logger.new_with_component_name::<Self>();
if self
.options
.contains(&ConnectionOptions::EnableWriteAheadLog)
Expand All @@ -120,30 +107,46 @@ impl ConnectionBuilder {
.with_context(|| "SQLite initialization: could not enable FOREIGN KEY support.")?;
}

let migrations = self.sql_migrations.clone();
self.apply_migrations(&connection, migrations)?;
if self
.options
.contains(&ConnectionOptions::ForceDisableForeignKeys)
{
debug!(logger, "Force disabling SQLite foreign key support");
connection
.execute("pragma foreign_keys=false")
.with_context(|| "SQLite initialization: could not disable FOREIGN KEY support.")?;
}
Ok(connection)
}

/// Apply a list of migration to the connection.
pub fn apply_migrations(
&self,
connection: &ConnectionThreadSafe,
sql_migrations: Vec<SqlMigration>,
) -> StdResult<()> {
let logger = self.base_logger.new_with_component_name::<Self>();

if sql_migrations.is_empty().not() {
// Check database migrations
debug!(logger, "Applying database migrations");
let mut db_checker =
DatabaseVersionChecker::new(self.base_logger, self.node_type, connection);
let mut db_checker = DatabaseVersionChecker::new(
self.base_logger.clone(),
self.node_type.clone(),
connection,
);

for migration in sql_migrations {
db_checker.add_migration(migration);
db_checker.add_migration(migration.clone());
}

db_checker
.apply()
.with_context(|| "Database migration error")?;
}

if self
.options
.contains(&ConnectionOptions::ForceDisableForeignKeys)
{
debug!(logger, "Force disabling SQLite foreign key support");
connection
.execute("pragma foreign_keys=false")
.with_context(|| "SQLite initialization: could not disable FOREIGN KEY support.")?;
}
Ok(())
}
}
Expand Down Expand Up @@ -290,4 +293,31 @@ mod tests {
let foreign_keys = execute_single_cell_query(&connection, "pragma foreign_keys;");
assert_eq!(Value::Integer(false.into()), foreign_keys);
}

#[test]
fn test_apply_a_partial_migrations() {
let migrations = vec![
SqlMigration::new(1, "create table first(id integer);"),
SqlMigration::new(2, "create table second(id integer);"),
];

let connection = ConnectionBuilder::open_memory().build().unwrap();

assert!(connection.prepare("select * from first;").is_err());
assert!(connection.prepare("select * from second;").is_err());

ConnectionBuilder::open_memory()
.apply_migrations(&connection, migrations[0..1].to_vec())
.unwrap();

assert!(connection.prepare("select * from first;").is_ok());
assert!(connection.prepare("select * from second;").is_err());

ConnectionBuilder::open_memory()
.apply_migrations(&connection, migrations)
.unwrap();

assert!(connection.prepare("select * from first;").is_ok());
assert!(connection.prepare("select * from second;").is_ok());
}
}
11 changes: 3 additions & 8 deletions mithril-signer/src/database/tests/protocol_initializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::Arc;

use mithril_common::test_utils::fake_data;
use mithril_common::{crypto_helper::ProtocolInitializer, entities::Epoch};
use mithril_persistence::sqlite::{ConnectionBuilder, ConnectionExtensions, ConnectionOptions};
use mithril_persistence::sqlite::{ConnectionBuilder, ConnectionExtensions};

use crate::database::repository::ProtocolInitializerRepository;
use crate::database::test_helper::{main_db_connection, FakeStoreAdapter};
Expand Down Expand Up @@ -180,12 +180,7 @@ mod migration {
async fn should_migrate_data_from_adapter() {
let migrations = crate::database::migration::get_migrations();

// TODO: Do it in test_helper (it is done by build_main_db_connection)
fn create_connection_builder() -> ConnectionBuilder {
ConnectionBuilder::open_memory()
.with_options(&[ConnectionOptions::ForceDisableForeignKeys])
}
let connection = Arc::new(create_connection_builder().build().unwrap());
let connection = Arc::new(ConnectionBuilder::open_memory().build().unwrap());
let protocol_initializer_adapter =
FakeStoreAdapter::new(connection.clone(), "protocol_initializer");
// The adapter will create the table.
Expand All @@ -207,7 +202,7 @@ mod migration {
assert!(protocol_initializer_adapter.is_key_hash_exist("HashEpoch5"));

// We finish the migration
create_connection_builder()
ConnectionBuilder::open_memory()
.apply_migrations(&connection, migrations)
.unwrap();

Expand Down
11 changes: 3 additions & 8 deletions mithril-signer/src/database/tests/stake_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::Arc;

use mithril_common::entities::{Epoch, StakeDistribution};
use mithril_common::signable_builder::StakeDistributionRetriever;
use mithril_persistence::sqlite::{ConnectionBuilder, ConnectionOptions};
use mithril_persistence::sqlite::ConnectionBuilder;
use mithril_persistence::store::StakeStorer;

use crate::database::repository::StakePoolStore;
Expand Down Expand Up @@ -243,12 +243,7 @@ mod migration {
async fn should_migrate_data_from_adapter() {
let migrations = crate::database::migration::get_migrations();

// TODO: Do it in test_helper (it is done by build_main_db_connection)
fn create_connection_builder() -> ConnectionBuilder {
ConnectionBuilder::open_memory()
.with_options(&[ConnectionOptions::ForceDisableForeignKeys])
}
let connection = Arc::new(create_connection_builder().build().unwrap());
let connection = Arc::new(ConnectionBuilder::open_memory().build().unwrap());

// The adapter will create the table.
let stake_adapter = FakeStoreAdapter::new(connection.clone(), "stake");
Expand All @@ -271,7 +266,7 @@ mod migration {
assert!(stake_adapter.is_key_hash_exist("HashEpoch5"));

// We finish the migration
create_connection_builder()
ConnectionBuilder::open_memory()
.apply_migrations(&connection, migrations)
.unwrap();
assert!(connection.prepare("select * from stake;").is_err());
Expand Down

0 comments on commit 75ce762

Please sign in to comment.