Skip to content

Commit

Permalink
Support 1 or 3 args in generate_series() UDTF
Browse files Browse the repository at this point in the history
  • Loading branch information
UBarney committed Dec 20, 2024
1 parent 9f530dd commit 5803583
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 75 deletions.
161 changes: 94 additions & 67 deletions datafusion/functions-table/src/generate_series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,25 @@ use async_trait::async_trait;
use datafusion_catalog::Session;
use datafusion_catalog::TableFunctionImpl;
use datafusion_catalog::TableProvider;
use datafusion_common::{not_impl_err, plan_err, Result, ScalarValue};
use datafusion_common::{plan_err, Result, ScalarValue};
use datafusion_expr::{Expr, TableType};
use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
use datafusion_physical_plan::ExecutionPlan;
use parking_lot::RwLock;
use std::fmt;
use std::sync::Arc;

#[derive(Debug, Clone)]
enum GenSeriesArgs {
ContainsNull,
AllNotNullArgs { start: i64, end: i64, step: i64 },
}

/// Table that generates a series of integers from `start`(inclusive) to `end`(inclusive)
#[derive(Debug, Clone)]
struct GenerateSeriesTable {
schema: SchemaRef,
// None if input is Null
start: Option<i64>,
// None if input is Null
end: Option<i64>,
args: GenSeriesArgs,
}

/// Table state that generates a series of integers from `start`(inclusive) to `end`(inclusive)
Expand All @@ -46,12 +49,23 @@ struct GenerateSeriesState {
schema: SchemaRef,
start: i64, // Kept for display
end: i64,
step: i64,
batch_size: usize,

/// Tracks current position when generating table
current: i64,
}

impl GenerateSeriesState {
fn reach_end(&self, val: i64) -> bool {
if self.step > 0 {
return val > self.end;
}

val < self.end
}
}

/// Detail to display for 'Explain' plan
impl fmt::Display for GenerateSeriesState {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
Expand All @@ -65,19 +79,19 @@ impl fmt::Display for GenerateSeriesState {

impl LazyBatchGenerator for GenerateSeriesState {
fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
// Check if we've reached the end
if self.current > self.end {
let mut buf = Vec::with_capacity(self.batch_size);
while buf.len() < self.batch_size && !self.reach_end(self.current) {
buf.push(self.current);
self.current += self.step;
}
let array = Int64Array::from(buf);

if array.is_empty() {
return Ok(None);
}

// Construct batch
let batch_end = (self.current + self.batch_size as i64 - 1).min(self.end);
let array = Int64Array::from_iter_values(self.current..=batch_end);
let batch = RecordBatch::try_new(self.schema.clone(), vec![Arc::new(array)])?;

// Update current position for next batch
self.current = batch_end + 1;

Ok(Some(batch))
}
}
Expand All @@ -104,77 +118,90 @@ impl TableProvider for GenerateSeriesTable {
_limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let batch_size = state.config_options().execution.batch_size;
match (self.start, self.end) {
(Some(start), Some(end)) => {
if start > end {
return plan_err!(
"End value must be greater than or equal to start value"
);
}

Ok(Arc::new(LazyMemoryExec::try_new(
self.schema.clone(),
vec![Arc::new(RwLock::new(GenerateSeriesState {
schema: self.schema.clone(),
start,
end,
current: start,
batch_size,
}))],
)?))
}
_ => {
// Either start or end is None, return a generator that outputs 0 rows
Ok(Arc::new(LazyMemoryExec::try_new(
self.schema.clone(),
vec![Arc::new(RwLock::new(GenerateSeriesState {
schema: self.schema.clone(),
start: 0,
end: 0,
current: 1,
batch_size,
}))],
)?))
}
}

let state = match self.args {
// if args have null, then return 0 row
GenSeriesArgs::ContainsNull => GenerateSeriesState {
schema: self.schema.clone(),
start: 0,
end: 0,
step: 1,
current: 1,
batch_size,
},
GenSeriesArgs::AllNotNullArgs { start, end, step } => GenerateSeriesState {
schema: self.schema.clone(),
start,
end,
step,
current: start,
batch_size,
},
};

Ok(Arc::new(LazyMemoryExec::try_new(
self.schema.clone(),
vec![Arc::new(RwLock::new(state))],
)?))
}
}

#[derive(Debug)]
pub struct GenerateSeriesFunc {}

impl TableFunctionImpl for GenerateSeriesFunc {
// Check input `exprs` type and number. Input validity check (e.g. start <= end)
// will be performed in `TableProvider::scan`
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
// TODO: support 1 or 3 arguments following DuckDB:
// <https://duckdb.org/docs/sql/functions/list#generate_series>
if exprs.len() == 3 || exprs.len() == 1 {
return not_impl_err!("generate_series does not support 1 or 3 arguments");
if exprs.is_empty() || exprs.len() > 3 {
return plan_err!("generate_series function requires 1 to 3 arguments");
}

if exprs.len() != 2 {
return plan_err!("generate_series expects 2 arguments");
let mut normalize_args = Vec::new();
for expr in exprs {
match expr {
Expr::Literal(ScalarValue::Null) => {}
Expr::Literal(ScalarValue::Int64(Some(n))) => normalize_args.push(*n),
_ => return plan_err!("First argument must be an integer literal"),
};
}

let start = match &exprs[0] {
Expr::Literal(ScalarValue::Null) => None,
Expr::Literal(ScalarValue::Int64(Some(n))) => Some(*n),
_ => return plan_err!("First argument must be an integer literal"),
};

let end = match &exprs[1] {
Expr::Literal(ScalarValue::Null) => None,
Expr::Literal(ScalarValue::Int64(Some(n))) => Some(*n),
_ => return plan_err!("Second argument must be an integer literal"),
};

let schema = Arc::new(Schema::new(vec![Field::new(
"value",
DataType::Int64,
false,
)]));

Ok(Arc::new(GenerateSeriesTable { schema, start, end }))
if normalize_args.len() != exprs.len() {
// contain null
return Ok(Arc::new(GenerateSeriesTable {
schema,
args: GenSeriesArgs::ContainsNull,
}));
}

let (start, end, step) = match &normalize_args[..] {
[end] => (0, *end, 1),
[start, end] => (*start, *end, 1),
[start, end, step] => (*start, *end, *step),
_ => {
return plan_err!("generate_series function requires 1 to 3 arguments");
}
};

if start > end && step > 0 {
return plan_err!("start is bigger than end, but increment is positive: cannot generate infinite series");
}

if start < end && step < 0 {
return plan_err!("start is smaller than end, but increment is negative: cannot generate infinite series");
}

if step == 0 {
return plan_err!("step cannot be zero");
}

Ok(Arc::new(GenerateSeriesTable {
schema,
args: GenSeriesArgs::AllNotNullArgs { start, end, step },
}))
}
}
63 changes: 55 additions & 8 deletions datafusion/sqllogictest/test_files/table_functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@
# under the License.

# Test generate_series table function
query I
SELECT * FROM generate_series(6)
----
0
1
2
3
4
5
6



query I rowsort
SELECT * FROM generate_series(1, 5)
Expand All @@ -39,11 +51,35 @@ SELECT * FROM generate_series(3, 6)
5
6

# #generated_data > batch_size
query I
SELECT count(v1) FROM generate_series(-66666,66666) t1(v1)
----
133333




query I rowsort
SELECT SUM(v1) FROM generate_series(1, 5) t1(v1)
----
15

query I
SELECT * FROM generate_series(6, -1, -2)
----
6
4
2
0

query I
SELECT * FROM generate_series(6, 66, 666)
----
6



# Test generate_series with WHERE clause
query I rowsort
SELECT * FROM generate_series(1, 10) t1(v1) WHERE v1 % 2 = 0
Expand Down Expand Up @@ -93,6 +129,10 @@ ON a.v1 = b.v1 - 1
2 3
3 4

#
# Test generate_series with null arguments
#

query I
SELECT * FROM generate_series(NULL, 5)
----
Expand All @@ -105,6 +145,11 @@ query I
SELECT * FROM generate_series(NULL, NULL)
----

query I
SELECT * FROM generate_series(1, 5, NULL)
----


query TT
EXPLAIN SELECT * FROM generate_series(1, 5)
----
Expand All @@ -115,20 +160,22 @@ physical_plan LazyMemoryExec: partitions=1, batch_generators=[generate_series: s
# Test generate_series with invalid arguments
#

query error DataFusion error: Error during planning: End value must be greater than or equal to start value
query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series
SELECT * FROM generate_series(5, 1)

statement error DataFusion error: This feature is not implemented: generate_series does not support 1 or 3 arguments
SELECT * FROM generate_series(1, 5, NULL)
query error DataFusion error: Error during planning: start is smaller than end, but increment is negative: cannot generate infinite series
SELECT * FROM generate_series(-6, 6, -1)

query error DataFusion error: Error during planning: step cannot be zero
SELECT * FROM generate_series(-6, 6, 0)

query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series
SELECT * FROM generate_series(6, -6, 1)

statement error DataFusion error: This feature is not implemented: generate_series does not support 1 or 3 arguments
SELECT * FROM generate_series(1)

statement error DataFusion error: Error during planning: generate_series expects 2 arguments
statement error DataFusion error: Error during planning: generate_series function requires 1 to 3 arguments
SELECT * FROM generate_series(1, 2, 3, 4)

statement error DataFusion error: Error during planning: Second argument must be an integer literal
SELECT * FROM generate_series(1, '2')

statement error DataFusion error: Error during planning: First argument must be an integer literal
SELECT * FROM generate_series('foo', 'bar')
Expand Down

0 comments on commit 5803583

Please sign in to comment.