diff --git a/wren-core/core/src/logical_plan/analyze/model_generation.rs b/wren-core/core/src/logical_plan/analyze/model_generation.rs index 6a6291d45..e53c99e6c 100644 --- a/wren-core/core/src/logical_plan/analyze/model_generation.rs +++ b/wren-core/core/src/logical_plan/analyze/model_generation.rs @@ -15,17 +15,19 @@ use crate::logical_plan::analyze::plan::{ use crate::logical_plan::utils::create_remote_table_source; use crate::mdl::manifest::Model; use crate::mdl::utils::quoted; -use crate::mdl::AnalyzedWrenMDL; +use crate::mdl::{AnalyzedWrenMDL, Dataset, SessionStateRef}; /// [ModelGenerationRule] is responsible for generating the model plan node. pub struct ModelGenerationRule { analyzed_wren_mdl: Arc, + session_state: SessionStateRef, } impl ModelGenerationRule { - pub fn new(mdl: Arc) -> Self { + pub fn new(mdl: Arc, session_state: SessionStateRef) -> Self { Self { analyzed_wren_mdl: mdl, + session_state, } } @@ -39,7 +41,10 @@ impl ModelGenerationRule { extension.node.as_any().downcast_ref::() { let source_plan = model_plan.relation_chain.clone().plan( - ModelGenerationRule::new(Arc::clone(&self.analyzed_wren_mdl)), + ModelGenerationRule::new( + Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), + ), )?; let result = match source_plan { Some(plan) => { @@ -67,15 +72,17 @@ impl ModelGenerationRule { .get_model(&model_plan.model_name) .expect("Model not found"), ); + let dataset = Dataset::Model(Arc::clone(&model)); // support table reference let table_scan = match &model_plan.original_table_scan { Some(LogicalPlan::TableScan(original_scan)) => { LogicalPlanBuilder::scan_with_filters( TableReference::from(model.table_reference()), create_remote_table_source( - &model, + &dataset, &self.analyzed_wren_mdl.wren_mdl(), - ), + Arc::clone(&self.session_state), + )?, None, original_scan.filters.clone(), ).expect("Failed to create table scan") @@ -89,7 +96,10 @@ impl ModelGenerationRule { None => { LogicalPlanBuilder::scan( TableReference::from(model.table_reference()), - create_remote_table_source(&model, &self.analyzed_wren_mdl.wren_mdl()), + create_remote_table_source( + &dataset, + &self.analyzed_wren_mdl.wren_mdl(), + Arc::clone(&self.session_state))?, None, ).expect("Failed to create table scan") .project(model_plan.required_exprs.clone())? @@ -111,7 +121,10 @@ impl ModelGenerationRule { .downcast_ref::( ) { let source_plan = calculation_plan.relation_chain.clone().plan( - ModelGenerationRule::new(Arc::clone(&self.analyzed_wren_mdl)), + ModelGenerationRule::new( + Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), + ), )?; if let Expr::Alias(alias) = calculation_plan.measures[0].clone() { diff --git a/wren-core/core/src/logical_plan/utils.rs b/wren-core/core/src/logical_plan/utils.rs index 667e737e6..1554666f8 100644 --- a/wren-core/core/src/logical_plan/utils.rs +++ b/wren-core/core/src/logical_plan/utils.rs @@ -3,7 +3,6 @@ use datafusion::arrow::datatypes::{ }; use datafusion::catalog_common::TableReference; use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion::datasource::DefaultTableSource; use datafusion::error::Result; use datafusion::logical_expr::{builder::LogicalTableSource, Expr, TableSource}; use log::debug; @@ -112,29 +111,16 @@ pub fn create_schema(columns: Vec>) -> Result { ))) } -pub fn create_remote_table_source(model: &Model, mdl: &WrenMDL) -> Arc { - if let Some(table_provider) = mdl.get_table(model.table_reference()) { - Arc::new(DefaultTableSource::new(table_provider)) - } else { - let fields: Vec = model - .get_physical_columns() - .iter() - .map(|column| { - let column = Arc::clone(column); - let name = if let Some(ref expression) = column.expression { - expression.clone() - } else { - column.name.clone() - }; - // TODO: find a way for the remote table to provide the data type - // We don't know the data type of the remote table, so we just mock a Int32 type here - Field::new(name, DataType::Int8, column.not_null) - }) - .collect(); - - let schema = SchemaRef::new(Schema::new_with_metadata(fields, HashMap::new())); - Arc::new(LogicalTableSource::new(schema)) - } +pub fn create_remote_table_source( + dataset: &Dataset, + mdl: &WrenMDL, + session_state_ref: SessionStateRef, +) -> Result> { + let schema = + dataset.to_remote_schema(Some(mdl.get_register_tables()), session_state_ref)?; + Ok(Arc::new(LogicalTableSource::new(Arc::new( + schema.as_arrow().clone(), + )))) } pub fn format_qualified_name( diff --git a/wren-core/core/src/mdl/dataset.rs b/wren-core/core/src/mdl/dataset.rs index a46b9d64c..6a0eee477 100644 --- a/wren-core/core/src/mdl/dataset.rs +++ b/wren-core/core/src/mdl/dataset.rs @@ -2,7 +2,6 @@ use crate::logical_plan::utils::map_data_type; use crate::mdl::manifest::{Column, Metric, Model}; use crate::mdl::utils::quoted; use crate::mdl::{RegisterTables, SessionStateRef}; -use datafusion::arrow::datatypes::DataType::Utf8; use datafusion::arrow::datatypes::Field; use datafusion::common::DFSchema; use datafusion::common::Result; @@ -75,7 +74,7 @@ impl Column { let columns = Self::collect_columns(expr); Ok(columns .into_iter() - .map(|c| Field::new(c.value, Utf8, false)) + .map(|c| Field::new(c.value, map_data_type(&self.r#type), false)) .collect()) } else { Ok(vec![self.to_field()]) diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index 91334db40..d768caa81 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -434,7 +434,9 @@ mod test { use crate::mdl::function::RemoteFunction; use crate::mdl::manifest::Manifest; use crate::mdl::{self, transform_sql_with_ctx, AnalyzedWrenMDL}; - use datafusion::arrow::array::{ArrayRef, Int64Array, RecordBatch, StringArray}; + use datafusion::arrow::array::{ + ArrayRef, Int64Array, RecordBatch, StringArray, TimestampNanosecondArray, + }; use datafusion::common::not_impl_err; use datafusion::common::Result; use datafusion::prelude::SessionContext; @@ -818,11 +820,43 @@ mod test { Arc::new(AnalyzedWrenMDL::default()), &[], sql, - ).await?; + ) + .await?; assert_eq!(actual, "SELECT current_date()"); Ok(()) } + /// This test will be failed if the `出道時間` is not inferred as a timestamp column correctly. + #[tokio::test] + async fn test_infer_timestamp_column() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_batch("artist", artist())?; + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("artist") + .table_reference("artist") + .column(ColumnBuilder::new("出道時間", "timestamp").build()) + .build(), + ) + .build(); + + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let sql = r#"select current_date > "出道時間" from wren.test.artist"#; + let actual = transform_sql_with_ctx( + &SessionContext::new(), + Arc::clone(&analyzed_mdl), + &[], + sql, + ) + .await?; + assert_eq!(actual, + "SELECT CAST(current_date() AS TIMESTAMP) > artist.\"出道時間\" FROM \ + (SELECT artist.\"出道時間\" FROM (SELECT artist.\"出道時間\" AS \"出道時間\" FROM artist) AS artist) AS artist"); + Ok(()) + } + async fn assert_sql_valid_executable(sql: &str) -> Result<()> { let ctx = SessionContext::new(); // To roundtrip testing, we should register the mock table for the planned sql. @@ -886,10 +920,13 @@ mod test { Arc::new(StringArray::from_iter_values(["Ina", "Azki", "Kaela"])); let group: ArrayRef = Arc::new(StringArray::from_iter_values(["EN", "JP", "ID"])); let subscribe: ArrayRef = Arc::new(Int64Array::from(vec![100, 200, 300])); + let debut_time: ArrayRef = + Arc::new(TimestampNanosecondArray::from(vec![1, 2, 3])); RecordBatch::try_from_iter(vec![ ("名字", name), ("組別", group), ("訂閱數", subscribe), + ("出道時間", debut_time), ]) .unwrap() }