Skip to content

Commit

Permalink
feat: support projection pushdown for datafusion iceberg (#594)
Browse files Browse the repository at this point in the history
* support projection pushdown for datafusion iceberg

* support projection pushdown for datafusion iceberg

* fix ci

* fix field id

* remove depencences

* remove depencences
  • Loading branch information
FANNG1 authored Sep 13, 2024
1 parent eae9464 commit cde35ab
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 18 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ iceberg = { version = "0.3.0", path = "./crates/iceberg" }
iceberg-catalog-rest = { version = "0.3.0", path = "./crates/catalog/rest" }
iceberg-catalog-hms = { version = "0.3.0", path = "./crates/catalog/hms" }
iceberg-catalog-memory = { version = "0.3.0", path = "./crates/catalog/memory" }
iceberg-datafusion = { version = "0.3.0", path = "./crates/integrations/datafusion" }
itertools = "0.13"
log = "0.4"
mockito = "1"
Expand Down
39 changes: 35 additions & 4 deletions crates/integrations/datafusion/src/physical_plan/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::vec;

use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef;
Expand All @@ -44,17 +45,25 @@ pub(crate) struct IcebergTableScan {
/// Stores certain, often expensive to compute,
/// plan properties used in query optimization.
plan_properties: PlanProperties,
/// Projection column names, None means all columns
projection: Option<Vec<String>>,
}

impl IcebergTableScan {
/// Creates a new [`IcebergTableScan`] object.
pub(crate) fn new(table: Table, schema: ArrowSchemaRef) -> Self {
pub(crate) fn new(
table: Table,
schema: ArrowSchemaRef,
projection: Option<&Vec<usize>>,
) -> Self {
let plan_properties = Self::compute_properties(schema.clone());
let projection = get_column_names(schema.clone(), projection);

Self {
table,
schema,
plan_properties,
projection,
}
}

Expand Down Expand Up @@ -100,7 +109,7 @@ impl ExecutionPlan for IcebergTableScan {
_partition: usize,
_context: Arc<TaskContext>,
) -> DFResult<SendableRecordBatchStream> {
let fut = get_batch_stream(self.table.clone());
let fut = get_batch_stream(self.table.clone(), self.projection.clone());
let stream = futures::stream::once(fut).try_flatten();

Ok(Box::pin(RecordBatchStreamAdapter::new(
Expand All @@ -116,7 +125,13 @@ impl DisplayAs for IcebergTableScan {
_t: datafusion::physical_plan::DisplayFormatType,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
write!(f, "IcebergTableScan")
write!(
f,
"IcebergTableScan projection:[{}]",
self.projection
.clone()
.map_or(String::new(), |v| v.join(","))
)
}
}

Expand All @@ -127,8 +142,13 @@ impl DisplayAs for IcebergTableScan {
/// and then converts it into a stream of Arrow [`RecordBatch`]es.
async fn get_batch_stream(
table: Table,
column_names: Option<Vec<String>>,
) -> DFResult<Pin<Box<dyn Stream<Item = DFResult<RecordBatch>> + Send>>> {
let table_scan = table.scan().build().map_err(to_datafusion_error)?;
let scan_builder = match column_names {
Some(column_names) => table.scan().select(column_names),
None => table.scan().select_all(),
};
let table_scan = scan_builder.build().map_err(to_datafusion_error)?;

let stream = table_scan
.to_arrow()
Expand All @@ -138,3 +158,14 @@ async fn get_batch_stream(

Ok(Box::pin(stream))
}

fn get_column_names(
schema: ArrowSchemaRef,
projection: Option<&Vec<usize>>,
) -> Option<Vec<String>> {
projection.map(|v| {
v.iter()
.map(|p| schema.field(*p).name().clone())
.collect::<Vec<String>>()
})
}
3 changes: 2 additions & 1 deletion crates/integrations/datafusion/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,14 @@ impl TableProvider for IcebergTableProvider {
async fn scan(
&self,
_state: &dyn Session,
_projection: Option<&Vec<usize>>,
projection: Option<&Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
) -> DFResult<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(IcebergTableScan::new(
self.table.clone(),
self.schema.clone(),
projection,
)))
}
}
109 changes: 96 additions & 13 deletions crates/integrations/datafusion/tests/integration_datafusion_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::vec;

use datafusion::arrow::array::{Array, StringArray};
use datafusion::arrow::datatypes::DataType;
use datafusion::execution::context::SessionContext;
use iceberg::io::FileIOBuilder;
use iceberg::spec::{NestedField, PrimitiveType, Schema, Type};
use iceberg::spec::{NestedField, PrimitiveType, Schema, StructType, Type};
use iceberg::{Catalog, NamespaceIdent, Result, TableCreation};
use iceberg_catalog_memory::MemoryCatalog;
use iceberg_datafusion::IcebergCatalogProvider;
Expand All @@ -39,6 +41,13 @@ fn get_iceberg_catalog() -> MemoryCatalog {
MemoryCatalog::new(file_io, Some(temp_path()))
}

fn get_struct_type() -> StructType {
StructType::new(vec![
NestedField::required(4, "s_foo1", Type::Primitive(PrimitiveType::Int)).into(),
NestedField::required(5, "s_foo2", Type::Primitive(PrimitiveType::String)).into(),
])
}

async fn set_test_namespace(catalog: &MemoryCatalog, namespace: &NamespaceIdent) -> Result<()> {
let properties = HashMap::new();

Expand All @@ -47,14 +56,21 @@ async fn set_test_namespace(catalog: &MemoryCatalog, namespace: &NamespaceIdent)
Ok(())
}

fn set_table_creation(location: impl ToString, name: impl ToString) -> Result<TableCreation> {
let schema = Schema::builder()
.with_schema_id(0)
.with_fields(vec![
NestedField::required(1, "foo", Type::Primitive(PrimitiveType::Int)).into(),
NestedField::required(2, "bar", Type::Primitive(PrimitiveType::String)).into(),
])
.build()?;
fn get_table_creation(
location: impl ToString,
name: impl ToString,
schema: Option<Schema>,
) -> Result<TableCreation> {
let schema = match schema {
None => Schema::builder()
.with_schema_id(0)
.with_fields(vec![
NestedField::required(1, "foo1", Type::Primitive(PrimitiveType::Int)).into(),
NestedField::required(2, "foo2", Type::Primitive(PrimitiveType::String)).into(),
])
.build()?,
Some(schema) => schema,
};

let creation = TableCreation::builder()
.location(location.to_string())
Expand All @@ -72,7 +88,7 @@ async fn test_provider_get_table_schema() -> Result<()> {
let namespace = NamespaceIdent::new("test_provider_get_table_schema".to_string());
set_test_namespace(&iceberg_catalog, &namespace).await?;

let creation = set_table_creation(temp_path(), "my_table")?;
let creation = get_table_creation(temp_path(), "my_table", None)?;
iceberg_catalog.create_table(&namespace, creation).await?;

let client = Arc::new(iceberg_catalog);
Expand All @@ -87,7 +103,7 @@ async fn test_provider_get_table_schema() -> Result<()> {
let table = schema.table("my_table").await.unwrap().unwrap();
let table_schema = table.schema();

let expected = [("foo", &DataType::Int32), ("bar", &DataType::Utf8)];
let expected = [("foo1", &DataType::Int32), ("foo2", &DataType::Utf8)];

for (field, exp) in table_schema.fields().iter().zip(expected.iter()) {
assert_eq!(field.name(), exp.0);
Expand All @@ -104,7 +120,7 @@ async fn test_provider_list_table_names() -> Result<()> {
let namespace = NamespaceIdent::new("test_provider_list_table_names".to_string());
set_test_namespace(&iceberg_catalog, &namespace).await?;

let creation = set_table_creation(temp_path(), "my_table")?;
let creation = get_table_creation(temp_path(), "my_table", None)?;
iceberg_catalog.create_table(&namespace, creation).await?;

let client = Arc::new(iceberg_catalog);
Expand All @@ -130,7 +146,6 @@ async fn test_provider_list_schema_names() -> Result<()> {
let namespace = NamespaceIdent::new("test_provider_list_schema_names".to_string());
set_test_namespace(&iceberg_catalog, &namespace).await?;

set_table_creation("test_provider_list_schema_names", "my_table")?;
let client = Arc::new(iceberg_catalog);
let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?);

Expand All @@ -147,3 +162,71 @@ async fn test_provider_list_schema_names() -> Result<()> {
.all(|item| result.contains(&item.to_string())));
Ok(())
}

#[tokio::test]
async fn test_table_projection() -> Result<()> {
let iceberg_catalog = get_iceberg_catalog();
let namespace = NamespaceIdent::new("ns".to_string());
set_test_namespace(&iceberg_catalog, &namespace).await?;

let schema = Schema::builder()
.with_schema_id(0)
.with_fields(vec![
NestedField::required(1, "foo1", Type::Primitive(PrimitiveType::Int)).into(),
NestedField::required(2, "foo2", Type::Primitive(PrimitiveType::String)).into(),
NestedField::optional(3, "foo3", Type::Struct(get_struct_type())).into(),
])
.build()?;
let creation = get_table_creation(temp_path(), "t1", Some(schema))?;
iceberg_catalog.create_table(&namespace, creation).await?;

let client = Arc::new(iceberg_catalog);
let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?);

let ctx = SessionContext::new();
ctx.register_catalog("catalog", catalog);
let table_df = ctx.table("catalog.ns.t1").await.unwrap();

let records = table_df
.clone()
.explain(false, false)
.unwrap()
.collect()
.await
.unwrap();
assert_eq!(1, records.len());
let record = &records[0];
// the first column is plan_type, the second column plan string.
let s = record
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(2, s.len());
// the first row is logical_plan, the second row is physical_plan
assert_eq!(
"IcebergTableScan projection:[foo1,foo2,foo3]",
s.value(1).trim()
);

// datafusion doesn't support query foo3.s_foo1, use foo3 instead
let records = table_df
.select_columns(&["foo1", "foo3"])
.unwrap()
.explain(false, false)
.unwrap()
.collect()
.await
.unwrap();
assert_eq!(1, records.len());
let record = &records[0];
let s = record
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(2, s.len());
assert_eq!("IcebergTableScan projection:[foo1,foo3]", s.value(1).trim());

Ok(())
}

0 comments on commit cde35ab

Please sign in to comment.