fix the remote column inferring
goldmedal committed Oct 31, 2024
1 parent 06f641a commit e02e9e2
Showing 4 changed files with 70 additions and 35 deletions.
27 changes: 20 additions & 7 deletions wren-core/core/src/logical_plan/analyze/model_generation.rs
@@ -15,17 +15,19 @@ use crate::logical_plan::analyze::plan::{
 use crate::logical_plan::utils::create_remote_table_source;
 use crate::mdl::manifest::Model;
 use crate::mdl::utils::quoted;
-use crate::mdl::AnalyzedWrenMDL;
+use crate::mdl::{AnalyzedWrenMDL, Dataset, SessionStateRef};
 
 /// [ModelGenerationRule] is responsible for generating the model plan node.
 pub struct ModelGenerationRule {
     analyzed_wren_mdl: Arc<AnalyzedWrenMDL>,
+    session_state: SessionStateRef,
 }
 
 impl ModelGenerationRule {
-    pub fn new(mdl: Arc<AnalyzedWrenMDL>) -> Self {
+    pub fn new(mdl: Arc<AnalyzedWrenMDL>, session_state: SessionStateRef) -> Self {
         Self {
             analyzed_wren_mdl: mdl,
+            session_state,
         }
     }
 
@@ -39,7 +41,10 @@
             extension.node.as_any().downcast_ref::<ModelPlanNode>()
         {
             let source_plan = model_plan.relation_chain.clone().plan(
-                ModelGenerationRule::new(Arc::clone(&self.analyzed_wren_mdl)),
+                ModelGenerationRule::new(
+                    Arc::clone(&self.analyzed_wren_mdl),
+                    Arc::clone(&self.session_state),
+                ),
             )?;
             let result = match source_plan {
                 Some(plan) => {
@@ -67,15 +72,17 @@
                     .get_model(&model_plan.model_name)
                     .expect("Model not found"),
             );
+            let dataset = Dataset::Model(Arc::clone(&model));
             // support table reference
             let table_scan = match &model_plan.original_table_scan {
                 Some(LogicalPlan::TableScan(original_scan)) => {
                     LogicalPlanBuilder::scan_with_filters(
                         TableReference::from(model.table_reference()),
                         create_remote_table_source(
-                            &model,
+                            &dataset,
                             &self.analyzed_wren_mdl.wren_mdl(),
-                        ),
+                            Arc::clone(&self.session_state),
+                        )?,
                         None,
                         original_scan.filters.clone(),
                     ).expect("Failed to create table scan")
@@ -89,7 +96,10 @@
                 None => {
                     LogicalPlanBuilder::scan(
                         TableReference::from(model.table_reference()),
-                        create_remote_table_source(&model, &self.analyzed_wren_mdl.wren_mdl()),
+                        create_remote_table_source(
+                            &dataset,
+                            &self.analyzed_wren_mdl.wren_mdl(),
+                            Arc::clone(&self.session_state))?,
                         None,
                     ).expect("Failed to create table scan")
                     .project(model_plan.required_exprs.clone())?
@@ -111,7 +121,10 @@
                 .downcast_ref::<CalculationPlanNode>(
             ) {
                 let source_plan = calculation_plan.relation_chain.clone().plan(
-                    ModelGenerationRule::new(Arc::clone(&self.analyzed_wren_mdl)),
+                    ModelGenerationRule::new(
+                        Arc::clone(&self.analyzed_wren_mdl),
+                        Arc::clone(&self.session_state),
+                    ),
                 )?;
 
                 if let Expr::Alias(alias) = calculation_plan.measures[0].clone() {
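Editor's note: the change above threads a session-state handle through ModelGenerationRule so nested relation chains are planned with the same context. A minimal, self-contained sketch of that pattern follows; the type names (GenerationRule, AnalyzedMdl, SessionState) are illustrative stand-ins, not the wren-core API.

    use std::sync::{Arc, RwLock};

    // Illustrative stand-ins for AnalyzedWrenMDL and DataFusion's SessionState.
    struct AnalyzedMdl;
    struct SessionState;
    type SessionStateRef = Arc<RwLock<SessionState>>;

    // The rule keeps cheap, clonable handles so it can rebuild itself when it
    // recurses into a nested relation chain, as the diff does with Arc::clone.
    struct GenerationRule {
        mdl: Arc<AnalyzedMdl>,
        session_state: SessionStateRef,
    }

    impl GenerationRule {
        fn new(mdl: Arc<AnalyzedMdl>, session_state: SessionStateRef) -> Self {
            Self { mdl, session_state }
        }

        fn for_nested_plan(&self) -> Self {
            // Clone the Arc handles, not the underlying data.
            Self::new(Arc::clone(&self.mdl), Arc::clone(&self.session_state))
        }
    }

    fn main() {
        let rule = GenerationRule::new(
            Arc::new(AnalyzedMdl),
            Arc::new(RwLock::new(SessionState)),
        );
        let _nested = rule.for_nested_plan();
    }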
34 changes: 10 additions & 24 deletions wren-core/core/src/logical_plan/utils.rs
@@ -3,7 +3,6 @@ use datafusion::arrow::datatypes::{
 };
 use datafusion::catalog_common::TableReference;
 use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
-use datafusion::datasource::DefaultTableSource;
 use datafusion::error::Result;
 use datafusion::logical_expr::{builder::LogicalTableSource, Expr, TableSource};
 use log::debug;
@@ -112,29 +111,16 @@ pub fn create_schema(columns: Vec<Arc<Column>>) -> Result<SchemaRef> {
     )))
 }
 
-pub fn create_remote_table_source(model: &Model, mdl: &WrenMDL) -> Arc<dyn TableSource> {
-    if let Some(table_provider) = mdl.get_table(model.table_reference()) {
-        Arc::new(DefaultTableSource::new(table_provider))
-    } else {
-        let fields: Vec<Field> = model
-            .get_physical_columns()
-            .iter()
-            .map(|column| {
-                let column = Arc::clone(column);
-                let name = if let Some(ref expression) = column.expression {
-                    expression.clone()
-                } else {
-                    column.name.clone()
-                };
-                // TODO: find a way for the remote table to provide the data type
-                // We don't know the data type of the remote table, so we just mock a Int32 type here
-                Field::new(name, DataType::Int8, column.not_null)
-            })
-            .collect();
-
-        let schema = SchemaRef::new(Schema::new_with_metadata(fields, HashMap::new()));
-        Arc::new(LogicalTableSource::new(schema))
-    }
+pub fn create_remote_table_source(
+    dataset: &Dataset,
+    mdl: &WrenMDL,
+    session_state_ref: SessionStateRef,
+) -> Result<Arc<dyn TableSource>> {
+    let schema =
+        dataset.to_remote_schema(Some(mdl.get_register_tables()), session_state_ref)?;
+    Ok(Arc::new(LogicalTableSource::new(Arc::new(
+        schema.as_arrow().clone(),
+    ))))
 }
 
 pub fn format_qualified_name(
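Editor's note: the old implementation mocked every remote column as Int8 when no table provider was registered, while the new one asks the dataset for its remote schema and wraps the resulting Arrow schema in a LogicalTableSource. A rough sketch of that idea follows, with a hard-coded schema standing in for whatever to_remote_schema would infer; it is illustrative only, not the wren-core code path.

    use std::sync::Arc;
    use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
    use datafusion::logical_expr::{builder::LogicalTableSource, TableSource};

    // Build a planner-only table source from an already-known Arrow schema.
    // In the diff the schema comes from Dataset::to_remote_schema; here it is
    // hard-coded purely for illustration.
    fn table_source_from_schema() -> Arc<dyn TableSource> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("名字", DataType::Utf8, false),
            Field::new(
                "出道時間",
                DataType::Timestamp(TimeUnit::Nanosecond, None),
                false,
            ),
        ]));
        Arc::new(LogicalTableSource::new(schema))
    }

    fn main() {
        let source = table_source_from_schema();
        // The planner now sees a real column type instead of a mocked Int8.
        assert_eq!(
            source.schema().field_with_name("出道時間").unwrap().data_type(),
            &DataType::Timestamp(TimeUnit::Nanosecond, None)
        );
    }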
3 changes: 1 addition & 2 deletions wren-core/core/src/mdl/dataset.rs
@@ -2,7 +2,6 @@ use crate::logical_plan::utils::map_data_type;
 use crate::mdl::manifest::{Column, Metric, Model};
 use crate::mdl::utils::quoted;
 use crate::mdl::{RegisterTables, SessionStateRef};
-use datafusion::arrow::datatypes::DataType::Utf8;
 use datafusion::arrow::datatypes::Field;
 use datafusion::common::DFSchema;
 use datafusion::common::Result;
Expand Down Expand Up @@ -75,7 +74,7 @@ impl Column {
let columns = Self::collect_columns(expr);
Ok(columns
.into_iter()
.map(|c| Field::new(c.value, Utf8, false))
.map(|c| Field::new(c.value, map_data_type(&self.r#type), false))
.collect())
} else {
Ok(vec![self.to_field()])
Expand Down
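Editor's note: with this change, the fields derived from a column's expression carry the column's declared MDL type (via map_data_type) instead of a hard-coded Utf8. Below is a hypothetical, simplified version of that kind of name-to-type mapping; the real map_data_type may differ in coverage, aliases, and defaults.

    use datafusion::arrow::datatypes::{DataType, TimeUnit};

    // Hypothetical, simplified mapping from MDL type names to Arrow types.
    // The actual map_data_type implementation is not shown in this commit.
    fn mdl_type_to_arrow(ty: &str) -> DataType {
        match ty.to_lowercase().as_str() {
            "timestamp" => DataType::Timestamp(TimeUnit::Nanosecond, None),
            "date" => DataType::Date32,
            "bigint" => DataType::Int64,
            "integer" | "int" => DataType::Int32,
            "double" => DataType::Float64,
            "boolean" | "bool" => DataType::Boolean,
            _ => DataType::Utf8, // fall back to a string type for unknown names
        }
    }

    fn main() {
        // The test below declares 出道時間 as "timestamp", so the inferred
        // field becomes a timestamp instead of Utf8.
        assert_eq!(
            mdl_type_to_arrow("timestamp"),
            DataType::Timestamp(TimeUnit::Nanosecond, None)
        );
    }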
41 changes: 39 additions & 2 deletions wren-core/core/src/mdl/mod.rs
@@ -434,7 +434,9 @@ mod test {
     use crate::mdl::function::RemoteFunction;
     use crate::mdl::manifest::Manifest;
     use crate::mdl::{self, transform_sql_with_ctx, AnalyzedWrenMDL};
-    use datafusion::arrow::array::{ArrayRef, Int64Array, RecordBatch, StringArray};
+    use datafusion::arrow::array::{
+        ArrayRef, Int64Array, RecordBatch, StringArray, TimestampNanosecondArray,
+    };
     use datafusion::common::not_impl_err;
     use datafusion::common::Result;
     use datafusion::prelude::SessionContext;
@@ -818,11 +820,43 @@
             Arc::new(AnalyzedWrenMDL::default()),
             &[],
             sql,
-        ).await?;
+        )
+        .await?;
         assert_eq!(actual, "SELECT current_date()");
         Ok(())
     }
 
+    /// This test will fail if `出道時間` is not correctly inferred as a timestamp column.
+    #[tokio::test]
+    async fn test_infer_timestamp_column() -> Result<()> {
+        let ctx = SessionContext::new();
+        ctx.register_batch("artist", artist())?;
+        let manifest = ManifestBuilder::new()
+            .catalog("wren")
+            .schema("test")
+            .model(
+                ModelBuilder::new("artist")
+                    .table_reference("artist")
+                    .column(ColumnBuilder::new("出道時間", "timestamp").build())
+                    .build(),
+            )
+            .build();
+
+        let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?);
+        let sql = r#"select current_date > "出道時間" from wren.test.artist"#;
+        let actual = transform_sql_with_ctx(
+            &SessionContext::new(),
+            Arc::clone(&analyzed_mdl),
+            &[],
+            sql,
+        )
+        .await?;
+        assert_eq!(actual,
+            "SELECT CAST(current_date() AS TIMESTAMP) > artist.\"出道時間\" FROM \
+            (SELECT artist.\"出道時間\" FROM (SELECT artist.\"出道時間\" AS \"出道時間\" FROM artist) AS artist) AS artist");
+        Ok(())
+    }
+
     async fn assert_sql_valid_executable(sql: &str) -> Result<()> {
         let ctx = SessionContext::new();
         // For roundtrip testing, we should register the mock table for the planned SQL.
@@ -886,10 +920,13 @@
             Arc::new(StringArray::from_iter_values(["Ina", "Azki", "Kaela"]));
         let group: ArrayRef = Arc::new(StringArray::from_iter_values(["EN", "JP", "ID"]));
         let subscribe: ArrayRef = Arc::new(Int64Array::from(vec![100, 200, 300]));
+        let debut_time: ArrayRef =
+            Arc::new(TimestampNanosecondArray::from(vec![1, 2, 3]));
         RecordBatch::try_from_iter(vec![
             ("名字", name),
             ("組別", group),
             ("訂閱數", subscribe),
+            ("出道時間", debut_time),
         ])
         .unwrap()
     }
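Editor's note: the new test exercises a date-versus-timestamp comparison once the column's real type is known. The standalone sketch below (not part of this commit) shows a quick way to check how plain DataFusion handles the same comparison; it assumes a tokio runtime and the datafusion crate, and uses only APIs that also appear in the test above.

    use std::sync::Arc;
    use datafusion::arrow::array::{ArrayRef, RecordBatch, TimestampNanosecondArray};
    use datafusion::error::Result;
    use datafusion::prelude::SessionContext;

    #[tokio::main]
    async fn main() -> Result<()> {
        // A single timestamp column, analogous to 出道時間 in the test above.
        let ts: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![1, 2, 3]));
        let batch = RecordBatch::try_from_iter(vec![("出道時間", ts)])?;

        let ctx = SessionContext::new();
        ctx.register_batch("artist", batch)?;

        // Plan and run the same date vs. timestamp comparison directly in
        // DataFusion, outside the wren transform.
        let df = ctx
            .sql(r#"SELECT current_date > "出道時間" FROM artist"#)
            .await?;
        df.show().await?;
        Ok(())
    }

This is only a sanity check outside the wren transform; the test above asserts the exact SQL that wren-core produces.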
