adding main function and tests
Alon Agmon committed Aug 27, 2024
1 parent f9c92b7 commit e8bd953
Showing 3 changed files with 299 additions and 6 deletions.
245 changes: 241 additions & 4 deletions crates/integrations/datafusion/src/physical_plan/scan.rs
@@ -23,12 +23,17 @@ use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef;
use datafusion::error::Result as DFResult;
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::logical_expr::{BinaryExpr, Operator};
use datafusion::physical_expr::EquivalenceProperties;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
DisplayAs, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties,
};
use datafusion::prelude::Expr;
use datafusion::scalar::ScalarValue;
use futures::{Stream, TryStreamExt};
use iceberg::expr::{Predicate, Reference};
use iceberg::spec::Datum;
use iceberg::table::Table;

use crate::to_datafusion_error;
@@ -44,17 +49,19 @@ pub(crate) struct IcebergTableScan {
/// Stores certain, often expensive to compute,
/// plan properties used in query optimization.
plan_properties: PlanProperties,
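/// Iceberg predicate derived from the pushed-down DataFusion
/// filters; `None` if no filter could be converted.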
predicates: Option<Predicate>,
}

impl IcebergTableScan {
/// Creates a new [`IcebergTableScan`] object.
pub(crate) fn new(table: Table, schema: ArrowSchemaRef) -> Self {
pub(crate) fn new(table: Table, schema: ArrowSchemaRef, filters: &[Expr]) -> Self {
let plan_properties = Self::compute_properties(schema.clone());

let predicates = convert_filters_to_predicate(filters);
Self {
table,
schema,
plan_properties,
predicates,
}
}

@@ -100,7 +107,7 @@ impl ExecutionPlan for IcebergTableScan {
_partition: usize,
_context: Arc<TaskContext>,
) -> DFResult<SendableRecordBatchStream> {
let fut = get_batch_stream(self.table.clone());
let fut = get_batch_stream(self.table.clone(), self.predicates.clone());
let stream = futures::stream::once(fut).try_flatten();

Ok(Box::pin(RecordBatchStreamAdapter::new(
@@ -127,8 +134,13 @@ impl DisplayAs for IcebergTableScan {
/// and then converts it into a stream of Arrow [`RecordBatch`]es.
async fn get_batch_stream(
table: Table,
predicates: Option<Predicate>,
) -> DFResult<Pin<Box<dyn Stream<Item = DFResult<RecordBatch>> + Send>>> {
let table_scan = table.scan().build().map_err(to_datafusion_error)?;
let mut scan_builder = table.scan();
if let Some(pred) = predicates {
scan_builder = scan_builder.with_filter(pred);
}
let table_scan = scan_builder.build().map_err(to_datafusion_error)?;

let stream = table_scan
.to_arrow()
@@ -138,3 +150,228 @@ async fn get_batch_stream(

Ok(Box::pin(stream))
}

/// Converts DataFusion filters ([`Expr`]) to an Iceberg [`Predicate`].
/// Returns `None` if none of the filters can be converted; otherwise returns
/// the converted predicates combined with an AND operator.
fn convert_filters_to_predicate(filters: &[Expr]) -> Option<Predicate> {
filters
.iter()
.filter_map(expr_to_predicate)
.reduce(Predicate::and)
}
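
A minimal usage sketch (not part of this commit; it relies on the `convert_filters_to_predicate` function above and the iceberg/DataFusion imports already in this file) showing that convertible filters are AND-ed together while unconvertible ones are silently dropped:

fn example_filter_conversion() {
    use datafusion::prelude::{col, lit};
    // "foo > 1" converts; "bar IS NULL" is not a BinaryExpr, so it is dropped.
    let filters = vec![col("foo").gt(lit(1)), col("bar").is_null()];
    let predicate = convert_filters_to_predicate(&filters).unwrap();
    assert_eq!(predicate, Reference::new("foo").greater_than(Datum::long(1)));
}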

/// Converts a DataFusion [`Expr`] to an Iceberg [`Predicate`].
///
/// This function handles the conversion of certain DataFusion expression types
/// to their corresponding Iceberg predicates. It supports the following cases:
///
/// 1. Simple binary expressions (e.g., "column < value")
/// 2. Compound AND expressions (e.g., "x < 1 AND y > 10")
/// 3. Compound OR expressions (e.g., "x < 1 OR y > 10")
///
/// For AND expressions, if one part of the expression can't be converted,
/// the function will still return a predicate for the part that can be converted.
/// For OR expressions, if any part can't be converted, the entire expression
/// will fail to convert.
///
/// # Arguments
///
/// * `expr` - A reference to a DataFusion [`Expr`] to be converted.
///
/// # Returns
///
/// * `Some(Predicate)` if the expression could be successfully converted.
/// * `None` if the expression couldn't be converted to an Iceberg predicate.
fn expr_to_predicate(expr: &Expr) -> Option<Predicate> {
match expr {
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
match (left.as_ref(), op, right.as_ref()) {
// first case: a simple binary expression, e.g. x < 1
(Expr::Column(col), op, Expr::Literal(lit)) => {
let reference = Reference::new(col.name.clone());
let datum = scalar_value_to_datum(lit)?;
Some(binary_op_to_predicate(reference, op, datum))
}
// second case (nested AND): x < 1 AND y > 10
// if one side of an AND cannot be converted, we can still push down the other
(left_expr, Operator::And, right_expr) => {
let left_pred = expr_to_predicate(left_expr);
let right_pred = expr_to_predicate(right_expr);
match (left_pred, right_pred) {
(Some(left), Some(right)) => Some(Predicate::and(left, right)),
(Some(left), None) => Some(left),
(None, Some(right)) => Some(right),
(None, None) => None,
}
}
// third case (nested OR): x < 1 OR y > 10
// if either side is unsupported, the whole predicate fails to convert
(Expr::BinaryExpr(left_expr), Operator::Or, Expr::BinaryExpr(right_expr)) => {
let left_pred = expr_to_predicate(&Expr::BinaryExpr(left_expr.clone()))?;
let right_pred = expr_to_predicate(&Expr::BinaryExpr(right_expr.clone()))?;
Some(Predicate::or(left_pred, right_pred))
}
_ => None,
}
}
_ => None,
}
}

/// Converts a DataFusion binary [`Operator`] applied to a column [`Reference`] and literal [`Datum`] to an Iceberg [`Predicate`].
fn binary_op_to_predicate(reference: Reference, op: &Operator, datum: Datum) -> Predicate {
match op {
Operator::Eq => reference.equal_to(datum),
Operator::NotEq => reference.not_equal_to(datum),
Operator::Lt => reference.less_than(datum),
Operator::LtEq => reference.less_than_or_equal_to(datum),
Operator::Gt => reference.greater_than(datum),
Operator::GtEq => reference.greater_than_or_equal_to(datum),
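// Unsupported operators fall back to AlwaysTrue so they do not restrict
// the scan; DataFusion re-applies the original filter exactly, since the
// pushdown is reported as Inexact.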
_ => Predicate::AlwaysTrue,
}
}

/// Converts a DataFusion [`ScalarValue`] to an Iceberg [`Datum`].
fn scalar_value_to_datum(value: &ScalarValue) -> Option<Datum> {
match value {
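// Note: all integer widths are widened to Iceberg long, floats to double.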
ScalarValue::Int8(Some(v)) => Some(Datum::long(*v as i64)),
ScalarValue::Int16(Some(v)) => Some(Datum::long(*v as i64)),
ScalarValue::Int32(Some(v)) => Some(Datum::long(*v as i64)),
ScalarValue::Int64(Some(v)) => Some(Datum::long(*v)),
ScalarValue::Float32(Some(v)) => Some(Datum::double(*v as f64)),
ScalarValue::Float64(Some(v)) => Some(Datum::double(*v)),
ScalarValue::Utf8(Some(v)) => Some(Datum::string(v.clone())),
// Add more cases as needed
_ => None,
}
}

#[cfg(test)]
mod tests {
use super::*;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::DFSchema;
use datafusion::prelude::SessionContext;

fn create_test_schema() -> DFSchema {
let arrow_schema = Schema::new(vec![
Field::new("foo", DataType::Int32, false),
Field::new("bar", DataType::Utf8, false),
]);
DFSchema::try_from_qualified_schema("my_table", &arrow_schema).unwrap()
}

fn create_test_schema_b() -> DFSchema {
let arrow_schema = Schema::new(vec![
Field::new("xxx", DataType::Int32, false),
Field::new("yyy", DataType::Utf8, false),
Field::new("zzz", DataType::Int32, false),
]);
DFSchema::try_from_qualified_schema("my_table", &arrow_schema).unwrap()
}

#[test]
fn test_predicate_conversion_with_single_condition() {
let sql = "foo > 1";
let df_schema = create_test_schema();
let expr = SessionContext::new()
.parse_sql_expr(sql, &df_schema)
.unwrap();
let predicate = convert_filters_to_predicate(&[expr]).unwrap();
assert_eq!(
predicate,
Reference::new("foo").greater_than(Datum::long(1))
);
}

#[test]
fn test_predicate_conversion_with_multiple_conditions() {
let sql = "foo > 1 and bar = 'test'";
let df_schema = create_test_schema();
let expr = SessionContext::new()
.parse_sql_expr(sql, &df_schema)
.unwrap();
let predicate = convert_filters_to_predicate(&[expr]).unwrap();
let inner_predicate = Predicate::and(
Reference::new("foo").greater_than(Datum::long(1)),
Reference::new("bar").equal_to(Datum::string("test")),
);
assert_eq!(predicate, inner_predicate);
}

#[test]
fn test_predicate_conversion_with_multiple_binary_expr() {
let sql = "(foo > 1 and bar = 'test') or foo < 0 ";
let df_schema = create_test_schema();
let expr = SessionContext::new()
.parse_sql_expr(sql, &df_schema)
.unwrap();
let predicate = convert_filters_to_predicate(&[expr]).unwrap();
let inner_predicate = Predicate::and(
Reference::new("foo").greater_than(Datum::long(1)),
Reference::new("bar").equal_to(Datum::string("test")),
);
let expected_predicate = Predicate::or(
inner_predicate,
Reference::new("foo").less_than(Datum::long(0)),
);
assert_eq!(predicate, expected_predicate);
}

#[test]
fn test_predicate_conversion_with_unsupported_condition_not() {
let sql = "xxx > 1 and yyy is not null and zzz < 0 ";
let df_schema = create_test_schema_b();
let expr = SessionContext::new()
.parse_sql_expr(sql, &df_schema)
.unwrap();
let predicate = convert_filters_to_predicate(&[expr]).unwrap();
let expected_predicate = Predicate::and(
Reference::new("xxx").greater_than(Datum::long(1)),
Reference::new("zzz").less_than(Datum::long(0)),
);
assert_eq!(predicate, expected_predicate);
}

#[test]
fn test_predicate_conversion_with_unsupported_condition_and() {
let sql = "(xxx > 1 and yyy in ('test', 'test2')) and zzz < 0 ";
let df_schema = create_test_schema_b();
let expr = SessionContext::new()
.parse_sql_expr(sql, &df_schema)
.unwrap();
let predicate = convert_filters_to_predicate(&[expr]).unwrap();
let expected_predicate = Predicate::and(
Reference::new("xxx").greater_than(Datum::long(1)),
Reference::new("zzz").less_than(Datum::long(0)),
);
assert_eq!(predicate, expected_predicate);
}

#[test]
fn test_predicate_conversion_with_unsupported_condition_or() {
let sql = "(foo > 1 and bar in ('test', 'test2')) or foo < 0 ";
let df_schema = create_test_schema();
let expr = SessionContext::new()
.parse_sql_expr(sql, &df_schema)
.unwrap();
let predicate = convert_filters_to_predicate(&[expr]).unwrap();
let expected_predicate = Predicate::or(
Reference::new("foo").greater_than(Datum::long(1)),
Reference::new("foo").less_than(Datum::long(0)),
);
assert_eq!(predicate, expected_predicate);
}

#[test]
fn test_predicate_conversion_with_unsupported_expr() {
let sql = "(xxx > 1 or yyy in ('test', 'test2')) and zzz < 0 ";
let df_schema = create_test_schema_b();
let expr = SessionContext::new()
.parse_sql_expr(sql, &df_schema)
.unwrap();
let predicate = convert_filters_to_predicate(&[expr]).unwrap();
let expected_predicate = Reference::new("zzz").less_than(Datum::long(0));
assert_eq!(predicate, expected_predicate);
}
}
24 changes: 22 additions & 2 deletions crates/integrations/datafusion/src/table.rs
@@ -18,19 +18,19 @@
use std::any::Any;
use std::sync::Arc;

use crate::physical_plan::scan::IcebergTableScan;
use async_trait::async_trait;
use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef;
use datafusion::catalog::Session;
use datafusion::datasource::{TableProvider, TableType};
use datafusion::error::Result as DFResult;
use datafusion::logical_expr::Expr;
use datafusion::logical_expr::{BinaryExpr, TableProviderFilterPushDown};
use datafusion::physical_plan::ExecutionPlan;
use iceberg::arrow::schema_to_arrow_schema;
use iceberg::table::Table;
use iceberg::{Catalog, NamespaceIdent, Result, TableIdent};

use crate::physical_plan::scan::IcebergTableScan;

/// Represents a [`TableProvider`] for the Iceberg [`Catalog`],
/// managing access to a [`Table`].
pub(crate) struct IcebergTableProvider {
@@ -82,6 +82,26 @@ impl TableProvider for IcebergTableProvider {
Ok(Arc::new(IcebergTableScan::new(
self.table.clone(),
self.schema.clone(),
_filters,
)))
}

fn supports_filters_pushdown(
&self,
filters: &[&Expr],
) -> std::result::Result<Vec<TableProviderFilterPushDown>, datafusion::error::DataFusionError>
{
let filter_support = filters
.iter()
.map(|e| {
if let Expr::BinaryExpr(BinaryExpr { .. }) = e {
TableProviderFilterPushDown::Inexact
} else {
TableProviderFilterPushDown::Unsupported
}
})
.collect::<Vec<TableProviderFilterPushDown>>();

Ok(filter_support)
}
}
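
Because `supports_filters_pushdown` reports `Inexact`, DataFusion keeps its own `FilterExec` above the scan: the converted Iceberg predicate only prunes data, and the filter is still applied exactly to the returned batches, so correctness does not depend on every filter converting. A hedged sketch (not part of the diff; `ctx` is a registered `SessionContext` and the table name is hypothetical) for observing this in the plan:

let df = ctx
    .sql("EXPLAIN SELECT * FROM catalog.ns.my_table WHERE foo > 1")
    .await?;
df.show().await?;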
@@ -147,3 +147,39 @@ async fn test_provider_list_schema_names() -> Result<()> {
.all(|item| result.contains(&item.to_string())));
Ok(())
}

#[tokio::test]
async fn test_table_scan() -> Result<()> {
let iceberg_catalog = get_iceberg_catalog();
let namespace = NamespaceIdent::new("test_provider_list_table_names".to_string());
set_test_namespace(&iceberg_catalog, &namespace).await?;
let creation = set_table_creation(temp_path(), "my_table")?;
let new_table = iceberg_catalog.create_table(&namespace, creation).await?;
let client = Arc::new(iceberg_catalog);
let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?);
let ctx = SessionContext::new();

ctx.register_catalog("catalog", catalog);
let df = ctx
.sql("select * from catalog.test_provider_list_table_names.my_table where (foo > 1 and bar = 'test') or foo < 0 ")
.await
.unwrap();

let compute_result = df.collect().await;
if let Ok(df) = compute_result {
println!("==> compute_result OK: {:?}", df);
} else {
println!(
"==> compute_result ERROR: {:?}",
compute_result.err().unwrap()
);
}
let provider = ctx.catalog("catalog").unwrap();
let schema = provider.schema("test_provider_list_table_names").unwrap();

let expected = vec!["my_table"];
let result = schema.table_names();

assert_eq!(result, expected);

Ok(())
}
