diff --git a/src/cli/mod.rs b/src/cli/mod.rs
index 5cbb9b4e..132aa1a1 100644
--- a/src/cli/mod.rs
+++ b/src/cli/mod.rs
@@ -30,7 +30,7 @@ impl SeafowlCli {
         rl.load_history(SEAFOWL_CLI_HISTORY).ok();
 
         loop {
-            match rl.readline(format!("{}> ", self.ctx.database).as_str()) {
+            match rl.readline(format!("{}> ", self.ctx.default_catalog).as_str()) {
                 Ok(line) if line.starts_with('\\') => {
                     rl.add_history_entry(line.trim_end())?;
                     let command = line.split_whitespace().collect::<Vec<_>>().join(" ");
diff --git a/src/config/context.rs b/src/config/context.rs
index cfbd3796..ec1f7a68 100644
--- a/src/config/context.rs
+++ b/src/config/context.rs
@@ -208,7 +208,8 @@ pub async fn build_context(cfg: &schema::SeafowlConfig) -> Result(
-        &self,
+        &'a self,
         name: impl Into<TableReference<'a>>,
         details: CreateDeltaTableDetails,
     ) -> Result<Arc<DeltaTable>> {
-        let table_ref: TableReference = name.into();
-        let resolved_ref = table_ref.resolve(&self.database, DEFAULT_SCHEMA);
+        let resolved_ref = self.resolve_table_ref(name);
         let schema_name = resolved_ref.schema.clone();
         let table_name = resolved_ref.table.clone();
 
         let _ = self
             .metastore
             .schemas
-            .get(&self.database, &schema_name)
+            .get(&self.default_catalog, &schema_name)
             .await?;
 
         // NB: there's also a uuid generated below for table's `DeltaTableMetaData::id`, so it would
@@ -353,7 +351,7 @@ impl SeafowlContext {
         self.metastore
             .tables
             .create(
-                &self.database,
+                &self.default_catalog,
                 &schema_name,
                 &table_name,
                 TableProvider::schema(&table).as_ref(),
diff --git a/src/context/logical.rs b/src/context/logical.rs
index c32ccdbc..49318f46 100644
--- a/src/context/logical.rs
+++ b/src/context/logical.rs
@@ -18,8 +18,9 @@ use datafusion::optimizer::analyzer::Analyzer;
 use datafusion::optimizer::optimizer::Optimizer;
 use datafusion::optimizer::simplify_expressions::SimplifyExpressions;
 use datafusion::optimizer::{OptimizerContext, OptimizerRule};
+use datafusion::prelude::SessionContext;
 use datafusion::sql::parser::{CopyToSource, CopyToStatement};
-use datafusion::{prelude::SessionContext, sql::TableReference};
+use datafusion_common::TableReference;
 use datafusion_expr::logical_plan::{Extension, LogicalPlan};
 use deltalake::DeltaTable;
 use itertools::Itertools;
@@ -325,8 +326,10 @@ impl SeafowlContext {
     // Should become obsolete once `sqlparser-rs` introduces support for some form of the `AS OF`
     // clause: https://en.wikipedia.org/wiki/SQL:2011.
     async fn rewrite_time_travel_query(&self, q: &mut Query) -> Result<SessionContext> {
-        let mut version_processor =
-            TableVersionProcessor::new(self.database.clone(), DEFAULT_SCHEMA.to_string());
+        let mut version_processor = TableVersionProcessor::new(
+            self.default_catalog.clone(),
+            DEFAULT_SCHEMA.to_string(),
+        );
         q.visit(&mut version_processor);
 
         if version_processor.table_versions.is_empty() {
@@ -348,7 +351,7 @@ impl SeafowlContext {
             let full_table_name = table.to_string();
             let mut resolved_ref = TableReference::from(full_table_name.as_str())
-                .resolve(&self.database, DEFAULT_SCHEMA);
+                .resolve(&self.default_catalog, &self.default_schema);
 
             // We only support datetime DeltaTable version specification for start
             let table_uuid = self.get_table_uuid(resolved_ref.clone()).await?;
diff --git a/src/context/mod.rs b/src/context/mod.rs
index 19f50d8e..c326cfaf 100644
--- a/src/context/mod.rs
+++ b/src/context/mod.rs
@@ -12,7 +12,7 @@ use crate::wasm_udf::wasm::create_udf_from_wasm;
 use base64::{engine::general_purpose::STANDARD, Engine};
 pub use datafusion::error::{DataFusionError as Error, Result};
 use datafusion::{error::DataFusionError, prelude::SessionContext, sql::TableReference};
-use datafusion_common::OwnedTableReference;
+use datafusion_common::{OwnedTableReference, ResolvedTableReference};
 use deltalake::DeltaTable;
 use object_store::path::Path;
 use std::sync::Arc;
@@ -24,18 +24,19 @@ pub struct SeafowlContext {
     pub inner: SessionContext,
     pub metastore: Arc<Metastore>,
     pub internal_object_store: Arc<InternalObjectStore>,
-    pub database: String,
+    pub default_catalog: String,
+    pub default_schema: String,
     pub max_partition_size: u32,
 }
 
 impl SeafowlContext {
-    // Create a new `SeafowlContext` with a new inner context scoped to a different default DB
-    pub fn scope_to_database(&self, name: String) -> Arc<SeafowlContext> {
+    // Create a new `SeafowlContext` with a new inner context scoped to a different default catalog/schema
+    pub fn scope_to(&self, catalog: String, schema: String) -> Arc<SeafowlContext> {
         // Swap the default catalog in the new internal context's session config
         let session_config = self
             .inner()
             .copied_config()
-            .with_default_catalog_and_schema(name.clone(), DEFAULT_SCHEMA);
+            .with_default_catalog_and_schema(&catalog, &schema);
 
         let state = build_state_with_table_factories(session_config, self.inner().runtime_env());
@@ -44,11 +45,20 @@ impl SeafowlContext {
             inner: SessionContext::new_with_state(state),
             metastore: self.metastore.clone(),
             internal_object_store: self.internal_object_store.clone(),
-            database: name,
+            default_catalog: catalog,
+            default_schema: schema,
             max_partition_size: self.max_partition_size,
         })
     }
 
+    pub fn scope_to_catalog(&self, catalog: String) -> Arc<SeafowlContext> {
+        self.scope_to(catalog, DEFAULT_SCHEMA.to_string())
+    }
+
+    pub fn scope_to_schema(&self, schema: String) -> Arc<SeafowlContext> {
+        self.scope_to(self.default_catalog.clone(), schema)
+    }
+
     pub fn inner(&self) -> &SessionContext {
         &self.inner
     }
@@ -67,18 +77,28 @@ impl SeafowlContext {
         // This does incur a latency cost to every query.
         self.inner.register_catalog(
-            &self.database,
-            Arc::new(self.metastore.build_catalog(&self.database).await?),
+            &self.default_catalog,
+            Arc::new(self.metastore.build_catalog(&self.default_catalog).await?),
         );
 
         // Register all functions in the database
         self.metastore
-            .build_functions(&self.database)
+            .build_functions(&self.default_catalog)
             .await?
             .iter()
             .try_for_each(|f| self.register_function(&f.name, &f.details))
     }
 
+    // Taken from DF SessionState where it's private
+    pub fn resolve_table_ref<'a>(
+        &'a self,
+        table_ref: impl Into<TableReference<'a>>,
+    ) -> ResolvedTableReference<'a> {
+        table_ref
+            .into()
+            .resolve(&self.default_catalog, &self.default_schema)
+    }
+
     // Check that the TableReference doesn't have a database/schema in it.
     // We create all external tables in the staging schema (backed by DataFusion's
     // in-memory schema provider) instead.
@@ -93,15 +113,15 @@ impl SeafowlContext {
         // This means that any potential catalog/schema references get condensed into the name, so
         // we have to unravel that name here again, and then resolve it properly.
         let reference = TableReference::from(name.to_string());
-        let resolved_reference = reference.resolve(&self.database, STAGING_SCHEMA);
+        let resolved_reference = reference.resolve(&self.default_catalog, STAGING_SCHEMA);
 
-        if resolved_reference.catalog != self.database
+        if resolved_reference.catalog != self.default_catalog
             || resolved_reference.schema != STAGING_SCHEMA
         {
             return Err(DataFusionError::Plan(format!(
                 "Can only create external tables in the staging schema. Omit the schema/database altogether or use {}.{}.{}",
-                &self.database, STAGING_SCHEMA, resolved_reference.table
+                &self.default_catalog, STAGING_SCHEMA, resolved_reference.table
             )));
         }
@@ -221,7 +241,7 @@ pub mod test_utils {
         // place on another node
         context.metastore.catalogs.create("testdb").await.unwrap();
 
-        let context = context.scope_to_database("testdb".to_string());
+        let context = context.scope_to_catalog("testdb".to_string());
 
         // Create new non-default collection
         context.plan_query("CREATE SCHEMA testcol").await.unwrap();
diff --git a/src/context/physical.rs b/src/context/physical.rs
index 35478de2..748c6bdd 100644
--- a/src/context/physical.rs
+++ b/src/context/physical.rs
@@ -177,7 +177,7 @@ impl SeafowlContext {
                 // Create a schema and register it
                 self.metastore
                     .schemas
-                    .create(&self.database, schema_name)
+                    .create(&self.default_catalog, schema_name)
                     .await?;
                 Ok(make_dummy_exec())
             }
@@ -518,8 +518,7 @@ impl SeafowlContext {
                 if_exists: _,
                 schema: _,
             })) => {
-                let table_ref = TableReference::from(name);
-                let resolved_ref = table_ref.resolve(&self.database, DEFAULT_SCHEMA);
+                let resolved_ref = self.resolve_table_ref(name);
 
                 if resolved_ref.schema == STAGING_SCHEMA {
                     // Dropping a staging table is an in-memory only op
@@ -549,7 +548,7 @@ impl SeafowlContext {
                 let schema_name = name.schema_name();
 
                 if let SchemaReference::Full { catalog, .. } = name
-                    && catalog != &self.database
+                    && catalog != &self.default_catalog
                 {
                     return Err(DataFusionError::Execution(
                         "Cannot delete schemas in other catalogs".to_string(),
@@ -558,7 +557,7 @@
                 let schema = match self
                     .inner
-                    .catalog(&self.database)
+                    .catalog(&self.default_catalog)
                     .expect("Current catalog exists")
                     .schema(schema_name)
                 {
@@ -577,7 +576,7 @@
                 // Delete each table sequentially
                 for table_name in schema.table_names() {
                     let table_ref = ResolvedTableReference {
-                        catalog: Cow::from(&self.database),
+                        catalog: Cow::from(&self.default_catalog),
                         schema: Cow::from(schema_name),
                         table: Cow::from(table_name),
                     };
@@ -591,7 +590,7 @@
 
                 self.metastore
                     .schemas
-                    .delete(&self.database, schema_name)
+                    .delete(&self.default_catalog, schema_name)
                     .await?;
 
                 Ok(make_dummy_exec())
@@ -642,7 +641,7 @@
                 // Persist the function in the metadata storage
                 self.metastore
                     .functions
-                    .create(&self.database, name, *or_replace, details)
+                    .create(&self.default_catalog, name, *or_replace, details)
                     .await?;
 
                 Ok(make_dummy_exec())
@@ -654,7 +653,7 @@
             }) => {
                 self.metastore
                     .functions
-                    .delete(&self.database, *if_exists, func_names)
+                    .delete(&self.default_catalog, *if_exists, func_names)
                     .await?;
                 Ok(make_dummy_exec())
             }
@@ -664,10 +663,8 @@
                 ..
             }) => {
                 // Resolve new table reference
-                let new_table_ref = TableReference::from(new_name.as_str());
-                let resolved_new_ref =
-                    new_table_ref.resolve(&self.database, DEFAULT_SCHEMA);
-                if resolved_new_ref.catalog != self.database {
+                let resolved_new_ref = self.resolve_table_ref(new_name);
+                if resolved_new_ref.catalog != self.default_catalog {
                     return Err(Error::Plan(
                         "Changing the table's database is not supported!"
                             .to_string(),
@@ -675,9 +672,7 @@
                 }
 
                 // Resolve old table reference
-                let old_table_ref = TableReference::from(old_name.as_str());
-                let resolved_old_ref =
-                    old_table_ref.resolve(&self.database, DEFAULT_SCHEMA);
+                let resolved_old_ref = self.resolve_table_ref(old_name);
 
                 // Finally update our catalog entry
                 self.metastore
@@ -701,9 +696,7 @@
                 if database.is_some() {
                     gc_databases(self, database.clone()).await;
                 } else if let Some(table_name) = table_name {
-                    let table_ref = TableReference::from(table_name.as_str());
-                    let resolved_ref =
-                        table_ref.resolve(&self.database, DEFAULT_SCHEMA);
+                    let resolved_ref = self.resolve_table_ref(table_name);
 
                     if let Ok(mut delta_table) =
                         self.try_get_delta_table(resolved_ref.clone()).await
@@ -829,8 +822,10 @@
         // Check whether table already exists and ensure that the schema exists
         let table_exists = match self
             .inner
-            .catalog(&self.database)
-            .ok_or_else(|| Error::Plan(format!("Database {} not found!", self.database)))?
+            .catalog(&self.default_catalog)
+            .ok_or_else(|| {
+                Error::Plan(format!("Database {} not found!", self.default_catalog))
+            })?
             .schema(&schema_name)
         {
             Some(_) => {
@@ -847,7 +842,7 @@ impl SeafowlContext {
                 // Schema doesn't exist; create one first, and then reload to pick it up
                 self.metastore
                     .schemas
-                    .create(&self.database, &schema_name)
+                    .create(&self.default_catalog, &schema_name)
                     .await?;
                 self.reload_schema().await?;
                 false
@@ -888,7 +883,7 @@ impl SeafowlContext {
         let plan = source.scan(&self.inner.state(), None, &[], None).await?;
 
         let table_ref = TableReference::Full {
-            catalog: Cow::from(&self.database),
+            catalog: Cow::from(&self.default_catalog),
             schema: Cow::from(schema_name),
             table: Cow::from(table_name),
         };
diff --git a/src/frontend/flight/handler.rs b/src/frontend/flight/handler.rs
index 3dd36c47..f70dab8d 100644
--- a/src/frontend/flight/handler.rs
+++ b/src/frontend/flight/handler.rs
@@ -9,6 +9,7 @@ use datafusion_common::DataFusionError;
 use lazy_static::lazy_static;
 use std::sync::Arc;
 use tokio::sync::Mutex;
+use tonic::metadata::MetadataMap;
 
 lazy_static! {
     pub static ref SEAFOWL_SQL_DATA: SqlInfoData = {
@@ -44,9 +45,23 @@ impl SeafowlFlightHandler {
         &self,
         query: &str,
         query_id: String,
+        metadata: &MetadataMap,
     ) -> Result<SchemaRef> {
-        let plan = self.context.plan_query(query).await?;
-        let batch_stream = self.context.execute_stream(plan).await?;
+        let ctx = if let Some(search_path) = metadata.get("search-path") {
+            self.context.scope_to_schema(
+                search_path
+                    .to_str()
+                    .map_err(|e| DataFusionError::Execution(format!(
+                        "Couldn't parse search path from header value {search_path:?}: {e}"
+                    )))?
+                    .to_string(),
+            )
+        } else {
+            self.context.clone()
+        };
+
+        let plan = ctx.plan_query(query).await?;
+        let batch_stream = ctx.execute_stream(plan).await?;
         let schema = batch_stream.schema();
 
         self.results.insert(query_id, Mutex::new(batch_stream));
diff --git a/src/frontend/flight/sql.rs b/src/frontend/flight/sql.rs
index fd01752c..b5b19e7c 100644
--- a/src/frontend/flight/sql.rs
+++ b/src/frontend/flight/sql.rs
@@ -81,7 +81,7 @@ impl FlightSqlService for SeafowlFlightHandler {
         let query_id = Uuid::new_v4().to_string();
 
         let schema = self
-            .query_to_stream(&query.query, query_id.clone())
+            .query_to_stream(&query.query, query_id.clone(), request.metadata())
             .await
             .map_err(|e| Status::internal(e.to_string()))?;
diff --git a/src/frontend/http.rs b/src/frontend/http.rs
index 5ccd397d..89cb9fed 100644
--- a/src/frontend/http.rs
+++ b/src/frontend/http.rs
@@ -159,8 +159,8 @@ pub async fn uncached_read_write_query(
     // If a specific DB name was used as a parameter in the route, scope the context to it,
     // effectively making it the default DB for the duration of the session.
-    if database_name != context.database {
-        context = context.scope_to_database(database_name);
+    if database_name != context.default_catalog {
+        context = context.scope_to_catalog(database_name);
     }
 
     let statements = context.parse_query(&query).await?;
@@ -324,8 +324,8 @@ pub async fn cached_read_query(
     // If a specific DB name was used as a parameter in the route, scope the context to it,
     // effectively making it the default DB for the duration of the session.
-    if database_name != context.database {
-        context = context.scope_to_database(database_name);
+    if database_name != context.default_catalog {
+        context = context.scope_to_catalog(database_name);
     }
 
     // Plan the query
@@ -382,8 +382,8 @@ pub async fn upload(
         return Err(ApiError::WriteForbidden);
     };
 
-    if database_name != context.database {
-        context = context.scope_to_database(database_name.clone());
+    if database_name != context.default_catalog {
+        context = context.scope_to_catalog(database_name.clone());
     }
 
     let mut has_header = true;
@@ -661,7 +661,7 @@ pub mod tests {
             .await
             .unwrap();
 
-            context = context.scope_to_database(db_name.to_string());
+            context = context.scope_to_catalog(db_name.to_string());
         }
 
         context
@@ -676,7 +676,7 @@ pub mod tests {
 
         if new_db.is_some() {
             // Re-scope to the original DB
-            return context.scope_to_database(DEFAULT_DB.to_string());
+            return context.scope_to_catalog(DEFAULT_DB.to_string());
         }
 
         context
@@ -688,7 +688,7 @@ pub mod tests {
         let mut context = in_memory_context_with_single_table(new_db).await;
 
         if let Some(db_name) = new_db {
-            context = context.scope_to_database(db_name.to_string());
+            context = context.scope_to_catalog(db_name.to_string());
         }
 
         context
@@ -698,7 +698,7 @@ pub mod tests {
 
         if new_db.is_some() {
             // Re-scope to the original DB
-            return context.scope_to_database(DEFAULT_DB.to_string());
+            return context.scope_to_catalog(DEFAULT_DB.to_string());
         }
 
         context
diff --git a/tests/flight/client.rs b/tests/flight/client.rs
index f642abaa..e0fdf124 100644
--- a/tests/flight/client.rs
+++ b/tests/flight/client.rs
@@ -1,30 +1,5 @@
 use crate::flight::*;
 
-async fn get_flight_batches(
-    client: &mut FlightClient,
-    query: String,
-) -> Result<Vec<RecordBatch>> {
-    let cmd = CommandStatementQuery {
-        query,
-        transaction_id: None,
-    };
-    let request = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
-    let response = client.get_flight_info(request).await?;
-
-    // Get the returned ticket
-    let ticket = response.endpoint[0]
-        .ticket
-        .clone()
-        .expect("expected ticket");
-
-    // Retrieve the corresponding Flight stream and collect into batches
-    let flight_stream = client.do_get(ticket).await?;
-
-    let batches = flight_stream.try_collect().await?;
-
-    Ok(batches)
-}
-
 #[tokio::test]
 async fn test_basic_queries() -> Result<()> {
     let (context, addr, flight) = start_flight_server().await;
diff --git a/tests/flight/mod.rs b/tests/flight/mod.rs
index 820cd606..2bcd3fe0 100644
--- a/tests/flight/mod.rs
+++ b/tests/flight/mod.rs
@@ -15,9 +15,11 @@ use std::net::SocketAddr;
 use std::pin::Pin;
 use std::sync::Arc;
 use tokio::net::TcpListener;
+use tonic::metadata::MetadataValue;
 use tonic::transport::Channel;
 
 mod client;
+mod search_path;
 
 async fn start_flight_server() -> (
     Arc<SeafowlContext>,
@@ -67,3 +69,28 @@ async fn create_flight_client(addr: SocketAddr) -> FlightClient {
 
     FlightClient::new(channel)
 }
+
+async fn get_flight_batches(
+    client: &mut FlightClient,
+    query: String,
+) -> Result<Vec<RecordBatch>> {
+    let cmd = CommandStatementQuery {
+        query,
+        transaction_id: None,
+    };
+    let request = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
+    let response = client.get_flight_info(request).await?;
+
+    // Get the returned ticket
+    let ticket = response.endpoint[0]
+        .ticket
+        .clone()
+        .expect("expected ticket");
+
+    // Retrieve the corresponding Flight stream and collect into batches
+    let flight_stream = client.do_get(ticket).await?;
+
+    let batches = flight_stream.try_collect().await?;
+
+    Ok(batches)
+}
diff --git a/tests/flight/search_path.rs b/tests/flight/search_path.rs
new file mode 100644
index 00000000..21d99f7a
--- /dev/null
+++ b/tests/flight/search_path.rs
@@ -0,0 +1,52 @@
+use crate::flight::*;
+
+#[tokio::test]
+async fn test_default_schema_override(
+) -> std::result::Result<(), Box<dyn std::error::Error>> {
+    let (context, addr, flight) = start_flight_server().await;
+
+    context.plan_query("CREATE SCHEMA some_schema").await?;
+    create_table_and_insert(context.as_ref(), "some_schema.flight_table").await;
+    tokio::task::spawn(flight);
+
+    let mut client = create_flight_client(addr).await;
+
+    // Trying to run the query without the search_path set will error out
+    let err = get_flight_batches(&mut client, "SELECT * FROM flight_table".to_string())
+        .await
+        .unwrap_err();
+    assert!(err
+        .to_string()
+        .contains("table 'default.public.flight_table' not found"));
+
+    // Now set the search_path header and re-run the query
+    client
+        .metadata_mut()
+        .insert("search-path", MetadataValue::from_static("some_schema"));
+
+    let results =
+        get_flight_batches(&mut client, "SELECT * FROM flight_table".to_string()).await?;
+
+    let expected = [
+        "+---------------------+------------+------------------+-----------------+----------------+",
+        "| some_time           | some_value | some_other_value | some_bool_value | some_int_value |",
+        "+---------------------+------------+------------------+-----------------+----------------+",
+        "| 2022-01-01T20:01:01 | 42.0       | 1.0000000000     |                 | 1111           |",
+        "| 2022-01-01T20:02:02 | 43.0       | 1.0000000000     |                 | 2222           |",
+        "| 2022-01-01T20:03:03 | 44.0       | 1.0000000000     |                 | 3333           |",
+        "+---------------------+------------+------------------+-----------------+----------------+",
+    ];
+
+    assert_batches_eq!(expected, &results);
+
+    // Finally, client needs to remove the header explicitly to avoid default schema override
+    client.metadata_mut().remove("search-path");
+    let err = get_flight_batches(&mut client, "SELECT * FROM flight_table".to_string())
+        .await
+        .unwrap_err();
+    assert!(err
+        .to_string()
+        .contains("table 'default.public.flight_table' not found"));
+
+    Ok(())
+}
diff --git a/tests/http/upload.rs b/tests/http/upload.rs
index f62d75b7..60a75424 100644
--- a/tests/http/upload.rs
+++ b/tests/http/upload.rs
@@ -113,7 +113,7 @@ async fn test_upload_base(
     // Verify the newly created table contents
     if let Some(db_name) = db_prefix {
-        context = context.scope_to_database(db_name.to_string());
+        context = context.scope_to_catalog(db_name.to_string());
     }
 
     let plan = context
         .plan_query(format!("SELECT * FROM test_upload.{table_name}").as_str())
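For reference, here is a minimal client-side sketch of how the new `search-path` header introduced above is meant to be used from an external Flight SQL client. It assumes a Seafowl Flight endpoint at `http://localhost:47470` (the address and the `some_schema`/`flight_table` names are illustrative, not from this diff) and reuses the `get_flight_batches` helper that this diff moves into `tests/flight/mod.rs`; outside the test crate you would inline the equivalent `CommandStatementQuery`/`get_flight_info`/`do_get` calls.

use arrow_flight::FlightClient;
use tonic::metadata::MetadataValue;
use tonic::transport::Channel;

async fn query_with_search_path() -> Result<(), Box<dyn std::error::Error>> {
    // Connect to a running Seafowl Flight SQL frontend (address is an assumption).
    let channel = Channel::from_static("http://localhost:47470")
        .connect()
        .await?;
    let mut client = FlightClient::new(channel);

    // Every subsequent request carries the header, so unqualified table names
    // resolve against `some_schema` instead of the default `public` schema.
    client
        .metadata_mut()
        .insert("search-path", MetadataValue::from_static("some_schema"));

    // `get_flight_batches` wraps CommandStatementQuery -> get_flight_info -> do_get
    // and collects the resulting record batches (see tests/flight/mod.rs above).
    let batches =
        get_flight_batches(&mut client, "SELECT * FROM flight_table".to_string()).await?;
    println!("got {} record batch(es)", batches.len());

    // Removing the header reverts name resolution to the context's default schema.
    client.metadata_mut().remove("search-path");
    Ok(())
}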