
Commit

add test
adriangb committed Jan 31, 2025
1 parent 693ba42 commit 9f42a35
Showing 3 changed files with 353 additions and 4 deletions.
353 changes: 351 additions & 2 deletions datafusion/core/tests/sql/metadata_columns.rs
@@ -19,12 +19,12 @@ use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::sync::{Arc, Mutex};

-use arrow_array::{ArrayRef, StringArray, UInt64Array};
+use arrow_array::{record_batch, ArrayRef, StringArray, UInt64Array};
use async_trait::async_trait;
use datafusion::arrow::array::{UInt64Builder, UInt8Builder};
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::arrow::record_batch::RecordBatch;
-use datafusion::datasource::{TableProvider, TableType};
+use datafusion::datasource::{MemTable, TableProvider, TableType};
use datafusion::error::Result;
use datafusion::execution::context::TaskContext;
use datafusion::physical_expr::EquivalenceProperties;
@@ -295,6 +295,69 @@ impl ExecutionPlan for CustomExec {
}
}

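/// A [`TableProvider`] that wraps a [`MemTable`] and exposes any field whose
/// schema metadata contains `datafusion.system_column = "true"` as a metadata
/// (system) column.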
#[derive(Debug)]
struct MetadataColumnTableProvider {
inner: MemTable,
}

impl MetadataColumnTableProvider {
fn new(batch: RecordBatch) -> Self {
let inner = MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap();
Self { inner }
}
}

#[async_trait::async_trait]
impl TableProvider for MetadataColumnTableProvider {
fn as_any(&self) -> &dyn Any {
self
}

fn schema(&self) -> SchemaRef {
self.inner.schema()
}

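/// Report as metadata columns all fields tagged with a `datafusion.system_column`
/// metadata value that starts with "t" (e.g. "true").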
fn metadata_columns(&self) -> Option<SchemaRef> {
let schema = self.schema();
let metadata_columns = schema
.fields()
.iter()
.filter(|f| {
if let Some(v) = f.metadata().get("datafusion.system_column") {
v.to_lowercase().starts_with("t")
} else {
false
}
})
.collect::<Vec<_>>();
if metadata_columns.is_empty() {
None
} else {
Some(Arc::new(Schema::new(
metadata_columns
.iter()
.cloned()
.cloned()
.collect::<Vec<_>>(),
)))
}
}

fn table_type(&self) -> TableType {
TableType::Base
}

async fn scan(
&self,
state: &dyn Session,
projection: Option<&Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
self.inner.scan(state, projection, filters, limit).await
}
}

#[tokio::test]
async fn select_conflict_name() {
// when reading csv, json or parquet, a normal column name may be the same as a metadata column name,
@@ -464,3 +527,289 @@ async fn select_metadata_column() {
];
assert_batches_sorted_eq!(expected, &batchs);
}

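// End-to-end coverage for system (metadata) columns: they are hidden from `SELECT *`,
// can be selected, filtered, ordered and joined on by name, lose to regular columns on
// name conflicts, and stop being system columns once they pass through a projection.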
#[tokio::test]
async fn test_select_system_column() {
let batch = record_batch!(
("id", UInt8, [1, 2, 3]),
("bank_account", UInt64, [9000, 100, 1000]),
("_rowid", UInt32, [0, 1, 2]),
("_file", Utf8, ["file-0", "file-1", "file-2"])
)
.unwrap();
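// Rebuild the schema so that `_rowid` and `_file` carry the
// `datafusion.system_column = "true"` metadata that marks them as system columns.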
let batch = batch
.with_schema(Arc::new(Schema::new(vec![
Field::new("id", DataType::UInt8, true),
Field::new("bank_account", DataType::UInt64, true),
Field::new("_rowid", DataType::UInt32, true).with_metadata(
[("datafusion.system_column".to_string(), "true".to_string())]
.iter()
.cloned()
.collect(),
),
Field::new("_file", DataType::Utf8, true).with_metadata(
[("datafusion.system_column".to_string(), "true".to_string())]
.iter()
.cloned()
.collect(),
),
])))
.unwrap();

let ctx = SessionContext::new_with_config(
SessionConfig::new().with_information_schema(true),
);
let _ = ctx
.register_table("test", Arc::new(MetadataColumnTableProvider::new(batch)))
.unwrap();

let select0 = "SELECT * FROM test order by id";
let df = ctx.sql(select0).await.unwrap();
let batchs = df.collect().await.unwrap();
#[rustfmt::skip]
let expected = [
"+----+--------------+",
"| id | bank_account |",
"+----+--------------+",
"| 1 | 9000 |",
"| 2 | 100 |",
"| 3 | 1000 |",
"+----+--------------+",
];
assert_batches_sorted_eq!(expected, &batchs);

let select1 = "SELECT _rowid FROM test order by _rowid";
let df = ctx.sql(select1).await.unwrap();
let batchs = df.collect().await.unwrap();
#[rustfmt::skip]
let expected = [
"+--------+",
"| _rowid |",
"+--------+",
"| 0 |",
"| 1 |",
"| 2 |",
"+--------+",
];
assert_batches_sorted_eq!(expected, &batchs);

let select2 = "SELECT _rowid, id FROM test order by _rowid";
let df = ctx.sql(select2).await.unwrap();
let batchs = df.collect().await.unwrap();
#[rustfmt::skip]
let expected = [
"+--------+----+",
"| _rowid | id |",
"+--------+----+",
"| 0 | 1 |",
"| 1 | 2 |",
"| 2 | 3 |",
"+--------+----+",
];
assert_batches_sorted_eq!(expected, &batchs);

let select3 = "SELECT _rowid, id FROM test WHERE _rowid = 0";
let df = ctx.sql(select3).await.unwrap();
let batchs = df.collect().await.unwrap();
#[rustfmt::skip]
let expected = [
"+--------+----+",
"| _rowid | id |",
"+--------+----+",
"| 0 | 1 |",
"+--------+----+",
];
assert_batches_sorted_eq!(expected, &batchs);

let select4 = "SELECT _rowid FROM test LIMIT 1";
let df = ctx.sql(select4).await.unwrap();
let batchs = df.collect().await.unwrap();
#[rustfmt::skip]
let expected = [
"+--------+",
"| _rowid |",
"+--------+",
"| 0 |",
"+--------+",
];
assert_batches_sorted_eq!(expected, &batchs);

let select5 = "SELECT _rowid, id FROM test WHERE _rowid % 2 = 1";
let df = ctx.sql(select5).await.unwrap();
let batchs = df.collect().await.unwrap();
#[rustfmt::skip]
let expected = [
"+--------+----+",
"| _rowid | id |",
"+--------+----+",
"| 1 | 2 |",
"+--------+----+",
];
assert_batches_sorted_eq!(expected, &batchs);

let select6 = "SELECT _rowid, _file FROM test order by _rowid";
let df = ctx.sql(select6).await.unwrap();
let batchs = df.collect().await.unwrap();
#[rustfmt::skip]
let expected = [
"+--------+--------+",
"| _rowid | _file |",
"+--------+--------+",
"| 0 | file-0 |",
"| 1 | file-1 |",
"| 2 | file-2 |",
"+--------+--------+",
];
assert_batches_sorted_eq!(expected, &batchs);

let select6 = "SELECT id FROM test order by _rowid asc";
let df = ctx.sql(select6).await.unwrap();
let batchs = df.collect().await.unwrap();
#[rustfmt::skip]
let expected = [
"+----+",
"| id |",
"+----+",
"| 1 |",
"| 2 |",
"| 3 |",
"+----+",
];
assert_batches_sorted_eq!(expected, &batchs);

let show_columns = "show columns from test;";
let df_columns = ctx.sql(show_columns).await.unwrap();
let batchs = df_columns
.select(vec![col("column_name"), col("data_type")])
.unwrap()
.collect()
.await
.unwrap();
let expected = [
"+--------------+-----------+",
"| column_name | data_type |",
"+--------------+-----------+",
"| id | UInt8 |",
"| bank_account | UInt64 |",
"+--------------+-----------+",
];
assert_batches_sorted_eq!(expected, &batchs);

let batch = record_batch!(
("other_id", UInt8, [1, 2, 3]),
("bank_account", UInt64, [9, 10, 11]),
("_rowid", UInt32, [10, 11, 12]) // not a system column!
)
.unwrap();
let _ = ctx
.register_table("test2", Arc::new(MetadataColumnTableProvider::new(batch)))
.unwrap();

// Normally _rowid would be a name conflict and planning would throw an error.
// But when the conflict is between a system column and a non-system column,
// the non-system column should be used.
let select7 =
"SELECT id, other_id, _rowid FROM test INNER JOIN test2 ON id = other_id";
let df = ctx.sql(select7).await.unwrap();
let batchs = df.collect().await.unwrap();
#[rustfmt::skip]
let expected = [
"+----+----------+---------+",
"| id | other_id | _rowid |",
"+----+----------+---------+",
"| 1 | 1 | 10 |",
"| 2 | 2 | 11 |",
"| 3 | 3 | 12 |",
"+----+----------+---------+",
];
assert_batches_sorted_eq!(expected, &batchs);

// Sanity check: for other columns we do get a conflict
let select7 =
"SELECT id, other_id, bank_account FROM test INNER JOIN test2 ON id = other_id";
assert!(ctx.sql(select7).await.is_err());

// Demonstrate that we can join on _rowid
let batch = record_batch!(
("other_id", UInt8, [2, 3, 4]),
("_rowid", UInt32, [2, 3, 4])
)
.unwrap();
let batch = batch
.with_schema(Arc::new(Schema::new(vec![
Field::new("other_id", DataType::UInt8, true),
Field::new("_rowid", DataType::UInt32, true).with_metadata(
[("datafusion.system_column".to_string(), "true".to_string())]
.iter()
.cloned()
.collect(),
),
])))
.unwrap();
let _ = ctx
.register_table("test2", Arc::new(MetadataColumnTableProvider::new(batch)))
.unwrap();

let select8 = "SELECT id, other_id, _rowid FROM test JOIN test2 ON _rowid = _rowid";
let df = ctx.sql(select8).await.unwrap();
let batches = df.collect().await.unwrap();
#[rustfmt::skip]
let expected = [
"+----+----------+---------+",
"| id | other_id | _rowid |",
"+----+----------+---------+",
"| 2 | 2 | 2 |",
"+----+----------+---------+",
];
assert_batches_sorted_eq!(expected, &batches);

// Once passed through a projection, system columns are no longer available
let select9 = r"
WITH cte AS (SELECT * FROM test)
SELECT * FROM cte
";
let df = ctx.sql(select9).await.unwrap();
let batches = df.collect().await.unwrap();
#[rustfmt::skip]
let expected = [
"+----+--------------+",
"| id | bank_account |",
"+----+--------------+",
"| 1 | 9000 |",
"| 2 | 100 |",
"| 3 | 1000 |",
"+----+--------------+",
];
assert_batches_sorted_eq!(expected, &batches);
let select10 = r"
WITH cte AS (SELECT * FROM test)
SELECT _rowid FROM cte
";
// `_rowid` was not carried through the cte's projection, so it can no longer be resolved.
assert!(ctx.sql(select10).await.is_err());

// But if a system column is explicitly selected and passed through a projection,
// it is carried along as a regular (non-system) column.
let select11 = r"
WITH cte AS (SELECT id, _rowid FROM test)
SELECT * FROM cte
";
let df = ctx.sql(select11).await.unwrap();
let batches = df.collect().await.unwrap();
#[rustfmt::skip]
let expected = [
"+----+---------+",
"| id | _rowid |",
"+----+---------+",
"| 1 | 0 |",
"| 2 | 1 |",
"| 3 | 2 |",
"+----+---------+",
];
assert_batches_sorted_eq!(expected, &batches);
}
2 changes: 1 addition & 1 deletion parquet-testing
