Skip to content

Commit

Permalink
feat: support customizing column default values for inserting (#8283)
Browse files Browse the repository at this point in the history
* parse column default values

* fix clippy

* Implement for memory table

* Add tests

* Add test

* Use plan_datafusion_err

* Add comment

* Update datafusion/sql/src/planner.rs

Co-authored-by: comphead <[email protected]>

* Fix ci

---------

Co-authored-by: comphead <[email protected]>
  • Loading branch information
jonahgao and comphead authored Nov 22, 2023
1 parent b46b7c0 commit 3dbda1e
Show file tree
Hide file tree
Showing 11 changed files with 160 additions and 7 deletions.
4 changes: 4 additions & 0 deletions datafusion/core/src/datasource/default_table_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ impl TableSource for DefaultTableSource {
/// Delegates to the wrapped `TableProvider`, exposing its logical plan
/// (if any) through the `TableSource` trait.
fn get_logical_plan(&self) -> Option<&datafusion_expr::LogicalPlan> {
self.table_provider.get_logical_plan()
}

/// Delegates to the wrapped `TableProvider`, exposing the default value
/// (if one was declared) for the named column through the `TableSource` trait.
fn get_column_default(&self, column: &str) -> Option<&Expr> {
self.table_provider.get_column_default(column)
}
}

/// Wrap TableProvider in TableSource
Expand Down
16 changes: 16 additions & 0 deletions datafusion/core/src/datasource/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
use datafusion_physical_plan::metrics::MetricsSet;
use futures::StreamExt;
use hashbrown::HashMap;
use log::debug;
use std::any::Any;
use std::fmt::{self, Debug};
Expand Down Expand Up @@ -56,6 +57,7 @@ pub struct MemTable {
schema: SchemaRef,
pub(crate) batches: Vec<PartitionData>,
constraints: Constraints,
column_defaults: HashMap<String, Expr>,
}

impl MemTable {
Expand All @@ -79,6 +81,7 @@ impl MemTable {
.map(|e| Arc::new(RwLock::new(e)))
.collect::<Vec<_>>(),
constraints: Constraints::empty(),
column_defaults: HashMap::new(),
})
}

Expand All @@ -88,6 +91,15 @@ impl MemTable {
self
}

/// Assign column default values (builder style).
///
/// `column_defaults` maps a column name to the expression used when an
/// `INSERT` omits that column. Consumes `self` and returns it so calls
/// can be chained, e.g. after [`MemTable::try_new`].
pub fn with_column_defaults(
mut self,
column_defaults: HashMap<String, Expr>,
) -> Self {
self.column_defaults = column_defaults;
self
}

/// Create a mem table by reading from another data source
pub async fn load(
t: Arc<dyn TableProvider>,
Expand Down Expand Up @@ -228,6 +240,10 @@ impl TableProvider for MemTable {
None,
)))
}

/// Returns the default expression for `column`, if one was registered via
/// `with_column_defaults`; `None` otherwise.
fn get_column_default(&self, column: &str) -> Option<&Expr> {
self.column_defaults.get(column)
}
}

/// Implements for writing to a [`MemTable`]
Expand Down
5 changes: 5 additions & 0 deletions datafusion/core/src/datasource/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ pub trait TableProvider: Sync + Send {
None
}

/// Get the default value for a column, if available.
///
/// The default implementation returns `None` (no declared default);
/// providers that track per-column `DEFAULT` expressions should override this.
fn get_column_default(&self, _column: &str) -> Option<&Expr> {
None
}

/// Create an [`ExecutionPlan`] for scanning the table with optionally
/// specified `projection`, `filter` and `limit`, described below.
///
Expand Down
14 changes: 11 additions & 3 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ impl SessionContext {
if_not_exists,
or_replace,
constraints,
column_defaults,
} = cmd;

let input = Arc::try_unwrap(input).unwrap_or_else(|e| e.as_ref().clone());
Expand All @@ -542,7 +543,12 @@ impl SessionContext {
let physical = DataFrame::new(self.state(), input);

let batches: Vec<_> = physical.collect_partitioned().await?;
let table = Arc::new(MemTable::try_new(schema, batches)?);
let table = Arc::new(
// pass constraints and column defaults to the mem table.
MemTable::try_new(schema, batches)?
.with_constraints(constraints)
.with_column_defaults(column_defaults.into_iter().collect()),
);

self.register_table(&name, table)?;
self.return_empty_dataframe()
Expand All @@ -557,8 +563,10 @@ impl SessionContext {

let batches: Vec<_> = physical.collect_partitioned().await?;
let table = Arc::new(
// pass constraints to the mem table.
MemTable::try_new(schema, batches)?.with_constraints(constraints),
// pass constraints and column defaults to the mem table.
MemTable::try_new(schema, batches)?
.with_constraints(constraints)
.with_column_defaults(column_defaults.into_iter().collect()),
);

self.register_table(&name, table)?;
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/logical_plan/ddl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ pub struct CreateMemoryTable {
pub if_not_exists: bool,
/// Option to replace table content if table already exists
pub or_replace: bool,
/// Default values for columns
pub column_defaults: Vec<(String, Expr)>,
}

/// Creates a view.
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,7 @@ impl LogicalPlan {
name,
if_not_exists,
or_replace,
column_defaults,
..
})) => Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(
CreateMemoryTable {
Expand All @@ -819,6 +820,7 @@ impl LogicalPlan {
name: name.clone(),
if_not_exists: *if_not_exists,
or_replace: *or_replace,
column_defaults: column_defaults.clone(),
},
))),
LogicalPlan::Ddl(DdlStatement::CreateView(CreateView {
Expand Down
5 changes: 5 additions & 0 deletions datafusion/expr/src/table_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,9 @@ pub trait TableSource: Sync + Send {
/// Get the logical plan backing this table source, if available.
/// The default implementation returns `None`.
fn get_logical_plan(&self) -> Option<&LogicalPlan> {
None
}

/// Get the default value for a column, if available.
///
/// The default implementation returns `None` (no declared default);
/// sources backed by a provider with column defaults should override this.
fn get_column_default(&self, _column: &str) -> Option<&Expr> {
None
}
}
41 changes: 39 additions & 2 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ use std::sync::Arc;
use std::vec;

use arrow_schema::*;
use datafusion_common::field_not_found;
use datafusion_common::internal_err;
use datafusion_common::{
field_not_found, internal_err, plan_datafusion_err, SchemaError,
};
use datafusion_expr::WindowUDF;
use sqlparser::ast::TimezoneInfo;
use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo};
Expand Down Expand Up @@ -230,6 +231,42 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Ok(Schema::new(fields))
}

/// Returns a vector of (column_name, default_expr) pairs
/// Returns a vector of `(column_name, default_expr)` pairs extracted from
/// the `DEFAULT` clauses of the given column definitions.
///
/// Column names are normalized with the planner's identifier normalizer;
/// columns without a `DEFAULT` option are omitted from the result.
///
/// # Errors
///
/// Default expressions are planned against an empty schema, so any column
/// reference surfaces as a `FieldNotFound` schema error, which is rewrapped
/// here with a clearer planning error message.
pub(super) fn build_column_defaults(
    &self,
    // `&[SQLColumnDef]` (not `&Vec<_>`) per clippy::ptr_arg; callers passing
    // `&Vec<SQLColumnDef>` still work via deref coercion.
    columns: &[SQLColumnDef],
    planner_context: &mut PlannerContext,
) -> Result<Vec<(String, Expr)>> {
    let mut column_defaults = vec![];
    // Default expressions are restricted: column references are not allowed,
    // so plan them against an empty schema to reject any column reference.
    let empty_schema = DFSchema::empty();
    let error_desc = |e: DataFusionError| match e {
        DataFusionError::SchemaError(SchemaError::FieldNotFound { .. }) => {
            plan_datafusion_err!(
                "Column reference is not allowed in the DEFAULT expression : {}",
                e
            )
        }
        _ => e,
    };

    for column in columns {
        // Pick out the (at most one) DEFAULT option among the column's options.
        if let Some(default_sql_expr) =
            column.options.iter().find_map(|o| match &o.option {
                ColumnOption::Default(expr) => Some(expr),
                _ => None,
            })
        {
            let default_expr = self
                .sql_to_expr(default_sql_expr.clone(), &empty_schema, planner_context)
                .map_err(error_desc)?;
            column_defaults
                .push((self.normalizer.normalize(column.name.clone()), default_expr));
        }
    }
    Ok(column_defaults)
}

/// Apply the given TableAlias to the input plan
pub(crate) fn apply_table_alias(
&self,
Expand Down
1 change: 1 addition & 0 deletions datafusion/sql/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
input: Arc::new(plan),
if_not_exists: false,
or_replace: false,
column_defaults: vec![],
}))
}
_ => plan,
Expand Down
15 changes: 13 additions & 2 deletions datafusion/sql/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let mut all_constraints = constraints;
let inline_constraints = calc_inline_constraints_from_columns(&columns);
all_constraints.extend(inline_constraints);
// Build column default values
let column_defaults =
self.build_column_defaults(&columns, planner_context)?;
match query {
Some(query) => {
let plan = self.query_to_plan(*query, planner_context)?;
Expand Down Expand Up @@ -250,6 +253,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
input: Arc::new(plan),
if_not_exists,
or_replace,
column_defaults,
},
)))
}
Expand All @@ -272,6 +276,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
input: Arc::new(plan),
if_not_exists,
or_replace,
column_defaults,
},
)))
}
Expand Down Expand Up @@ -1170,8 +1175,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
datafusion_expr::Expr::Column(source_field.qualified_column())
.cast_to(target_field.data_type(), source.schema())?
}
// Fill the default value for the column, currently only supports NULL.
None => datafusion_expr::Expr::Literal(ScalarValue::Null)
// The value is not specified. Fill in the default value for the column.
None => table_source
.get_column_default(target_field.name())
.cloned()
.unwrap_or_else(|| {
// If there is no default for the column, then the default is NULL
datafusion_expr::Expr::Literal(ScalarValue::Null)
})
.cast_to(target_field.data_type(), &DFSchema::empty())?,
};
Ok(expr.alias(target_field.name()))
Expand Down
62 changes: 62 additions & 0 deletions datafusion/sqllogictest/test_files/insert.slt
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,65 @@ insert into bad_new_empty_table values (1);

statement ok
drop table bad_new_empty_table;


### Test for specifying column's default value

statement ok
create table test_column_defaults(
a int,
b int not null default null,
c int default 100*2+300,
d text default lower('DEFAULT_TEXT'),
e timestamp default now()
)

query IIITP
insert into test_column_defaults values(1, 10, 100, 'ABC', now())
----
1

statement error DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable
insert into test_column_defaults(a) values(2)

query IIITP
insert into test_column_defaults(b) values(20)
----
1

query IIIT rowsort
select a,b,c,d from test_column_defaults
----
1 10 100 ABC
NULL 20 500 default_text

statement ok
drop table test_column_defaults


# test create table as
statement ok
create table test_column_defaults(
a int,
b int not null default null,
c int default 100*2+300,
d text default lower('DEFAULT_TEXT'),
e timestamp default now()
) as values(1, 10, 100, 'ABC', now())

query IIITP
insert into test_column_defaults(b) values(20)
----
1

query IIIT rowsort
select a,b,c,d from test_column_defaults
----
1 10 100 ABC
NULL 20 500 default_text

statement ok
drop table test_column_defaults

statement error DataFusion error: Error during planning: Column reference is not allowed in the DEFAULT expression : Schema error: No field named a.
create table test_column_defaults(a int, b int default a+1)

0 comments on commit 3dbda1e

Please sign in to comment.