Skip to content

Commit

Permalink
Support (order by / sort) for DataFrameWriteOptions
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuqi-lucas committed Dec 23, 2024
1 parent 8fd792f commit 6eb83f7
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 9 deletions.
185 changes: 177 additions & 8 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,15 @@ use crate::physical_plan::{
collect, collect_partitioned, execute_stream, execute_stream_partitioned,
ExecutionPlan, SendableRecordBatchStream,
};
use crate::prelude::SessionContext;
use crate::prelude::{CsvReadOptions, NdJsonReadOptions, ParquetReadOptions, SessionContext};

use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
use arrow::compute::{cast, concat};
use arrow::datatypes::{DataType, Field};
use arrow_array::Int32Array;
use arrow_schema::{Schema, SchemaRef};
use datafusion_common::config::{CsvOptions, JsonOptions};
use datafusion_common::{
exec_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, ParamValues,
SchemaError, UnnestOptions,
};
use datafusion_common::{assert_batches_eq, exec_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, ParamValues, SchemaError, UnnestOptions};
use datafusion_expr::dml::InsertOp;
use datafusion_expr::{case, is_null, lit, SortExpr};
use datafusion_expr::{
Expand All @@ -63,6 +61,7 @@ use datafusion_functions_aggregate::expr_fn::{
};

use async_trait::async_trait;
use tempfile::TempDir;
use datafusion_catalog::Session;

/// Contains options that control how data is
Expand All @@ -77,6 +76,9 @@ pub struct DataFrameWriteOptions {
/// Sets which columns should be used for hive-style partitioned writes by name.
/// Can be set to empty vec![] for non-partitioned writes.
partition_by: Vec<String>,
/// Sets which columns should be used for sorting the output by name.
/// Can be set to empty vec![] for non-sorted writes.
sort_by: Vec<SortExpr>,
}

impl DataFrameWriteOptions {
Expand All @@ -86,6 +88,7 @@ impl DataFrameWriteOptions {
insert_op: InsertOp::Append,
single_file_output: false,
partition_by: vec![],
sort_by: vec![],
}
}

Expand All @@ -106,6 +109,12 @@ impl DataFrameWriteOptions {
self.partition_by = partition_by;
self
}

/// Sets the sort_by columns for output sorting
pub fn with_sort_by(mut self, sort_by: Vec<SortExpr>) -> Self {
self.sort_by = sort_by;
self
}
}

impl Default for DataFrameWriteOptions {
Expand Down Expand Up @@ -1517,8 +1526,15 @@ impl DataFrame {
write_options: DataFrameWriteOptions,
) -> Result<Vec<RecordBatch>, DataFusionError> {
let arrow_schema = Schema::from(self.schema());

let plan = if write_options.sort_by.is_empty() {
self.plan
} else {
LogicalPlanBuilder::from(self.plan).sort(write_options.sort_by)?.build()?
};

let plan = LogicalPlanBuilder::insert_into(
self.plan,
plan,
table_name.to_owned(),
&arrow_schema,
write_options.insert_op,
Expand Down Expand Up @@ -1577,8 +1593,14 @@ impl DataFrame {

let file_type = format_as_file_type(format);

let plan = if options.sort_by.is_empty() {
self.plan
} else {
LogicalPlanBuilder::from(self.plan).sort(options.sort_by)?.build()?
};

let plan = LogicalPlanBuilder::copy_to(
self.plan,
plan,
path.into(),
file_type,
HashMap::new(),
Expand Down Expand Up @@ -1638,8 +1660,14 @@ impl DataFrame {

let file_type = format_as_file_type(format);

let plan = if options.sort_by.is_empty() {
self.plan
} else {
LogicalPlanBuilder::from(self.plan).sort(options.sort_by)?.build()?
};

let plan = LogicalPlanBuilder::copy_to(
self.plan,
plan,
path.into(),
file_type,
Default::default(),
Expand Down Expand Up @@ -4058,3 +4086,144 @@ mod tests {
Ok(())
}
}

// Test issue: https://github.com/apache/datafusion/issues/13873
#[tokio::test]
async fn write_parquet_with_order() -> Result<()> {
let tmp_dir = TempDir::new()?;
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
]));

let ctx = SessionContext::new();
let write_df = ctx.read_batch(RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 5, 7, 3,2])),
Arc::new(Int32Array::from(vec![2, 3, 4, 5,6])),
],
)?)?;

let test_path = tmp_dir.path().join("test.parquet");

write_df.clone().write_parquet(test_path.to_str().unwrap(), DataFrameWriteOptions::new().with_sort_by(vec![col("a").sort(true, true)]),
None,).await?;

let ctx = SessionContext::new();
ctx.register_parquet("data", test_path.to_str().unwrap(), ParquetReadOptions::default())
.await?;

let df = ctx.sql("SELECT * FROM data").await?;
let results = df.collect().await?;

assert_batches_eq!(
&[
"+---+---+",
"| a | b |",
"+---+---+",
"| 1 | 2 |",
"| 2 | 6 |",
"| 3 | 5 |",
"| 5 | 3 |",
"| 7 | 4 |",
"+---+---+",
],
&results
);
Ok(())
}

// Test issue: https://github.com/apache/datafusion/issues/13873
#[tokio::test]
async fn write_csv_with_order() -> Result<()> {
let tmp_dir = TempDir::new()?;
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
]));

let ctx = SessionContext::new();
let write_df = ctx.read_batch(RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 5, 7, 3,2])),
Arc::new(Int32Array::from(vec![2, 3, 4, 5,6])),
],
)?)?;

let test_path = tmp_dir.path().join("test.csv");

write_df.clone().write_csv(test_path.to_str().unwrap(), DataFrameWriteOptions::new().with_sort_by(vec![col("a").sort(true, true)]),
None,).await?;

let ctx = SessionContext::new();
ctx.register_csv("data", test_path.to_str().unwrap(), CsvReadOptions::new().schema(&schema))
.await?;

let df = ctx.sql("SELECT * FROM data").await?;
let results = df.collect().await?;

assert_batches_eq!(
&[
"+---+---+",
"| a | b |",
"+---+---+",
"| 1 | 2 |",
"| 2 | 6 |",
"| 3 | 5 |",
"| 5 | 3 |",
"| 7 | 4 |",
"+---+---+",
],
&results
);
Ok(())
}

// Test issue: https://github.com/apache/datafusion/issues/13873
#[tokio::test]
async fn write_json_with_order() -> Result<()> {
let tmp_dir = TempDir::new()?;
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
]));

let ctx = SessionContext::new();
let write_df = ctx.read_batch(RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 5, 7, 3,2])),
Arc::new(Int32Array::from(vec![2, 3, 4, 5,6])),
],
)?)?;

let test_path = tmp_dir.path().join("test.json");

write_df.clone().write_json(test_path.to_str().unwrap(), DataFrameWriteOptions::new().with_sort_by(vec![col("a").sort(true, true)]),
None,).await?;

let ctx = SessionContext::new();
ctx.register_json("data", test_path.to_str().unwrap(), NdJsonReadOptions::default().schema(&schema))
.await?;

let df = ctx.sql("SELECT * FROM data").await?;
let results = df.collect().await?;

assert_batches_eq!(
&[
"+---+---+",
"| a | b |",
"+---+---+",
"| 1 | 2 |",
"| 2 | 6 |",
"| 3 | 5 |",
"| 5 | 3 |",
"| 7 | 4 |",
"+---+---+",
],
&results
);
Ok(())
}
8 changes: 7 additions & 1 deletion datafusion/core/src/dataframe/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,14 @@ impl DataFrame {

let file_type = format_as_file_type(format);

let plan = if options.sort_by.is_empty() {
self.plan
} else {
LogicalPlanBuilder::from(self.plan).sort(options.sort_by)?.build()?
};

let plan = LogicalPlanBuilder::copy_to(
self.plan,
plan,
path.into(),
file_type,
Default::default(),
Expand Down

0 comments on commit 6eb83f7

Please sign in to comment.