Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 65 additions & 14 deletions datafusion/core/benches/sql_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,28 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame {
// the actual ops here are largely unimportant as they are just a sample
// of ops that could occur on a dataframe
df = df
.with_column(&c_name, cast(c.clone(), DataType::Utf8))
.with_column(
&c_name,
cast(
c.clone(),
Arc::new(Field::new(&c_name, DataType::Utf8, true)),
),
)
.unwrap()
.with_column(
&c_name,
when(
cast(c.clone(), DataType::Int32).gt(lit(135)),
cast(
cast(c.clone(), DataType::Int32) - lit(i + 3),
DataType::Utf8,
c.clone(),
Arc::new(Field::new(&c_name, DataType::Int32, true)),
)
.gt(lit(135)),
cast(
cast(
c.clone(),
Arc::new(Field::new(&c_name, DataType::Int32, true)),
) - lit(i + 3),
Arc::new(Field::new(&c_name, DataType::Utf8, true)),
),
)
.otherwise(c.clone())
Expand All @@ -122,15 +135,25 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame {
&c_name,
when(
c.clone().is_not_null().and(
cast(c.clone(), DataType::Int32)
.between(lit(120), lit(130)),
cast(
c.clone(),
Arc::new(Field::new(&c_name, DataType::Int32, true)),
)
.between(lit(120), lit(130)),
),
Literal(ScalarValue::Utf8(None), None),
)
.otherwise(
when(
c.clone().is_not_null().and(regexp_like(
cast(c.clone(), DataType::Utf8View),
cast(
c.clone(),
Arc::new(Field::new(
&c_name,
DataType::Utf8View,
true,
)),
),
lit("[0-9]*"),
None,
)),
Expand All @@ -146,10 +169,16 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame {
&c_name,
when(
c.clone().is_not_null().and(
cast(c.clone(), DataType::Int32)
.between(lit(90), lit(100)),
cast(
c.clone(),
Arc::new(Field::new(&c_name, DataType::Int32, true)),
)
.between(lit(90), lit(100)),
),
cast(
c.clone(),
Arc::new(Field::new(&c_name, DataType::Utf8View, true)),
),
cast(c.clone(), DataType::Utf8View),
)
.otherwise(Literal(ScalarValue::Date32(None), None))
.unwrap(),
Expand All @@ -159,10 +188,22 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame {
&c_name,
when(
c.clone().is_not_null().and(
cast(c.clone(), DataType::Int32).rem(lit(10)).gt(lit(7)),
cast(
c.clone(),
Arc::new(Field::new(&c_name, DataType::Int32, true)),
)
.rem(lit(10))
.gt(lit(7)),
),
regexp_replace(
cast(c.clone(), DataType::Utf8View),
cast(
c.clone(),
Arc::new(Field::new(
&c_name,
DataType::Utf8View,
true,
)),
),
lit("1"),
lit("a"),
None,
Expand All @@ -179,11 +220,21 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame {
&c_name,
try_cast(
to_timestamp(vec![c.clone(), lit("%Y-%m-%d %H:%M:%S")]),
DataType::Timestamp(Nanosecond, Some("UTC".into())),
Arc::new(Field::new(
&c_name,
DataType::Timestamp(Nanosecond, Some("UTC".into())),
true,
)),
),
)
.unwrap()
.with_column(&c_name, try_cast(c.clone(), DataType::Date32))
.with_column(
&c_name,
try_cast(
c.clone(),
Arc::new(Field::new(&c_name, DataType::Date32, true)),
),
)
.unwrap()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ impl TableProvider for CustomProvider {
Expr::Literal(ScalarValue::Int16(Some(i)), _) => *i as i64,
Expr::Literal(ScalarValue::Int32(Some(i)), _) => *i as i64,
Expr::Literal(ScalarValue::Int64(Some(i)), _) => *i,
Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() {
Expr::Cast(Cast { expr, field: _ }) => match expr.deref() {
Expr::Literal(lit_value, _) => match lit_value {
ScalarValue::Int8(Some(v)) => *v as i64,
ScalarValue::Int16(Some(v)) => *v as i64,
Expand Down
12 changes: 9 additions & 3 deletions datafusion/core/tests/dataframe/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,10 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {

// the arg2 parameter is a complex expr, but it can be evaluated to the literal value
let alias_expr = Expr::Alias(Alias::new(
cast(lit(0.5), DataType::Float32),
cast(
lit(0.5),
Arc::new(Field::new("arg_2", DataType::Float32, true)),
),
None::<&str>,
"arg_2".to_string(),
));
Expand All @@ -463,7 +466,10 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
);

let alias_expr = Expr::Alias(Alias::new(
cast(lit(0.1), DataType::Float32),
cast(
lit(0.1),
Arc::new(Field::new("arg_2", DataType::Float32, true)),
),
None::<&str>,
"arg_2".to_string(),
));
Expand Down Expand Up @@ -1129,7 +1135,7 @@ async fn test_fn_substr() -> Result<()> {

#[tokio::test]
async fn test_cast() -> Result<()> {
let expr = cast(col("b"), DataType::Float64);
let expr = cast(col("b"), Arc::new(Field::new("b", DataType::Float64, true)));
let batches = get_batches(expr).await?;

assert_snapshot!(
Expand Down
30 changes: 25 additions & 5 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2071,7 +2071,13 @@ async fn cast_expr_test() -> Result<()> {
.await?
.select_columns(&["c2", "c3"])?
.limit(0, Some(1))?
.with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;
.with_column(
"sum",
cast(
col("c2") + col("c3"),
Arc::new(Field::new("sum", DataType::Int64, false)),
),
)?;

let df_results = df.clone().collect().await?;
df.clone().show().await?;
Expand Down Expand Up @@ -2173,7 +2179,13 @@ async fn cache_test() -> Result<()> {
.await?
.select_columns(&["c2", "c3"])?
.limit(0, Some(1))?
.with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;
.with_column(
"sum",
cast(
col("c2") + col("c3"),
Arc::new(Field::new("sum", DataType::Int64, false)),
),
)?;

let cached_df = df.clone().cache().await?;

Expand Down Expand Up @@ -2671,8 +2683,13 @@ async fn write_table_with_order() -> Result<()> {
.unwrap();

// Ensure the column type matches the target table
write_df =
write_df.with_column("tablecol1", cast(col("tablecol1"), DataType::Utf8View))?;
write_df = write_df.with_column(
"tablecol1",
cast(
col("tablecol1"),
Arc::new(Field::new("tablecol1", DataType::Utf8View, false)),
),
)?;

let sql_str =
"create external table data(tablecol1 varchar) stored as parquet location '"
Expand Down Expand Up @@ -4687,7 +4704,10 @@ async fn consecutive_projection_same_schema() -> Result<()> {
let df = df
.with_column(
"t",
cast(Expr::Literal(ScalarValue::Null, None), DataType::Int32),
cast(
Expr::Literal(ScalarValue::Null, None),
Arc::new(Field::new("t", DataType::Int32, true)),
),
)
.unwrap();
df.clone().show().await.unwrap();
Expand Down
15 changes: 12 additions & 3 deletions datafusion/core/tests/expr_api/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,10 @@ fn make_udf_add(volatility: Volatility) -> Arc<ScalarUDF> {
}

fn cast_to_int64_expr(expr: Expr) -> Expr {
Expr::Cast(Cast::new(expr.into(), DataType::Int64))
Expr::Cast(Cast::new(
expr.into(),
Arc::new(Field::new("cast_to_i64", DataType::Int64, true)),
))
}

fn to_timestamp_expr(arg: impl Into<String>) -> Expr {
Expand Down Expand Up @@ -747,8 +750,14 @@ fn test_simplify_concat() -> Result<()> {
#[test]
fn test_simplify_cycles() {
// cast(now() as int64) < cast(to_timestamp(0) as int64) + i64::MAX
let expr = cast(now(), DataType::Int64)
.lt(cast(to_timestamp(vec![lit(0)]), DataType::Int64) + lit(i64::MAX));
let expr = cast(
now(),
Arc::new(Field::new("cast_to_i64", DataType::Int64, true)),
)
.lt(cast(
to_timestamp(vec![lit(0)]),
Arc::new(Field::new("cast_to_i64", DataType::Int64, true)),
) + lit(i64::MAX));
let expected = lit(true);
test_simplify_with_cycle_count(expr, expected, 3);
}
Original file line number Diff line number Diff line change
Expand Up @@ -713,8 +713,12 @@ impl ScalarUDFImpl for CastToI64UDF {
} else {
// need to use an actual cast to get the correct type
Expr::Cast(datafusion_expr::Cast {
expr: Box::new(arg),
data_type: DataType::Int64,
expr: Box::new(arg.clone()),
field: Arc::new(Field::new(
"cast_to_i64",
DataType::Int64,
info.nullable(&arg)?,
)),
})
};
// return the newly written argument to DataFusion
Expand Down
18 changes: 14 additions & 4 deletions datafusion/datasource-parquet/src/row_group_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -868,9 +868,13 @@ mod tests {
.with_scale(0)
.with_precision(9);
let schema_descr = get_test_schema_descr(vec![field]);
let expr = cast(col("c1"), Decimal128(11, 2)).gt(cast(
let expr = cast(
col("c1"),
Arc::new(Field::new("c1", Decimal128(11, 2), true)),
)
.gt(cast(
lit(ScalarValue::Decimal128(Some(500), 5, 2)),
Decimal128(11, 2),
Arc::new(Field::new("c1", Decimal128(11, 2), true)),
));
let expr = logical2physical(&expr, &schema);
let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap();
Expand Down Expand Up @@ -1023,7 +1027,10 @@ mod tests {
.with_byte_len(16);
let schema_descr = get_test_schema_descr(vec![field]);
// cast the type of c1 to decimal(28,3)
let left = cast(col("c1"), Decimal128(28, 3));
let left = cast(
col("c1"),
Arc::new(Field::new("c1", Decimal128(28, 3), true)),
);
let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3)));
let expr = logical2physical(&expr, &schema);
let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap();
Expand Down Expand Up @@ -1101,7 +1108,10 @@ mod tests {
.with_byte_len(16);
let schema_descr = get_test_schema_descr(vec![field]);
// cast the type of c1 to decimal(28,3)
let left = cast(col("c1"), Decimal128(28, 3));
let left = cast(
col("c1"),
Arc::new(Field::new("c1", Decimal128(28, 3), true)),
);
let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3)));
let expr = logical2physical(&expr, &schema);
let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap();
Expand Down
Loading
Loading