diff --git a/Cargo.lock b/Cargo.lock index 10c9c055..2b03f422 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1291,6 +1291,7 @@ dependencies = [ "object_store", "reqwest 0.11.27", "serde_json", + "sqlx", "tempfile", "testcontainers 0.20.0", "testcontainers-modules", diff --git a/datafusion_iceberg/Cargo.toml b/datafusion_iceberg/Cargo.toml index a130e535..f3c2f77b 100644 --- a/datafusion_iceberg/Cargo.toml +++ b/datafusion_iceberg/Cargo.toml @@ -34,3 +34,4 @@ testcontainers = "0.20.0" tokio-stream = { version = "0.1.15", features = ["io-util"] } tempfile = "3.10.1" reqwest = "0.11" +sqlx = { version = "0.7.4", features = ["runtime-tokio", "tls-rustls", "any", "sqlite", "postgres", "mysql"], default-features = false } diff --git a/iceberg-sql-catalog/src/lib.rs b/iceberg-sql-catalog/src/lib.rs index a5756047..da90b8e1 100644 --- a/iceberg-sql-catalog/src/lib.rs +++ b/iceberg-sql-catalog/src/lib.rs @@ -31,7 +31,7 @@ use object_store::ObjectStore; use sqlx::{ any::{install_default_drivers, AnyPoolOptions, AnyRow}, pool::PoolOptions, - AnyPool, Row, + AnyPool, Executor, Row, }; use uuid::Uuid; @@ -55,16 +55,17 @@ impl SqlCatalog { ) -> Result { install_default_drivers(); - let mut options = PoolOptions::new(); + let mut pool_options = PoolOptions::new(); if url.starts_with("sqlite") { - options = options.max_connections(1); + pool_options = pool_options.max_connections(1); } - let pool = AnyPoolOptions::connect(options, &url).await?; - - sqlx::query( - "create table if not exists iceberg_tables ( + let pool = AnyPoolOptions::after_connect(pool_options, |connection, _| { + Box::pin(async move { + connection + .execute( + "create table if not exists iceberg_tables ( catalog_name varchar(255) not null, table_namespace varchar(255) not null, table_name varchar(255) not null, @@ -72,22 +73,24 @@ impl SqlCatalog { previous_metadata_location varchar(255), primary key (catalog_name, table_namespace, table_name) );", - ) - .execute(&pool) - .await?; - - sqlx::query( - "create table if not exists iceberg_namespace_properties ( + ) + .await?; + connection + .execute( + "create table if not exists iceberg_namespace_properties ( catalog_name varchar(255) not null, namespace varchar(255) not null, property_key varchar(255), property_value varchar(255), primary key (catalog_name, namespace, property_key) );", - ) - .execute(&pool) - .await - .map_err(Error::from)?; + ) + .await?; + Ok(()) + }) + }) + .connect(&url) + .await?; Ok(SqlCatalog { name: name.to_owned(), @@ -709,10 +712,17 @@ impl SqlCatalogList { pub async fn new(url: &str, object_store: Arc) -> Result { install_default_drivers(); - let pool = AnyPoolOptions::connect(PoolOptions::new().max_connections(1), &url).await?; + let mut pool_options = PoolOptions::new(); + + if url.starts_with("sqlite") { + pool_options = pool_options.max_connections(1); + } - sqlx::query( - "create table if not exists iceberg_tables ( + let pool = AnyPoolOptions::after_connect(pool_options, |connection, _| { + Box::pin(async move { + connection + .execute( + "create table if not exists iceberg_tables ( catalog_name varchar(255) not null, table_namespace varchar(255) not null, table_name varchar(255) not null, @@ -720,22 +730,24 @@ impl SqlCatalogList { previous_metadata_location varchar(255), primary key (catalog_name, table_namespace, table_name) );", - ) - .execute(&pool) - .await?; - - sqlx::query( - "create table if not exists iceberg_namespace_properties ( + ) + .await?; + connection + .execute( + "create table if not exists iceberg_namespace_properties ( catalog_name varchar(255) not null, namespace varchar(255) not null, property_key varchar(255), property_value varchar(255), primary key (catalog_name, namespace, property_key) );", - ) - .execute(&pool) - .await - .map_err(Error::from)?; + ) + .await?; + Ok(()) + }) + }) + .connect(&url) + .await?; Ok(SqlCatalogList { pool, object_store }) }