
Commit

init
PeterKeDer committed Oct 4, 2024
1 parent 8f4232f commit aedeaaa
Showing 4 changed files with 144 additions and 2 deletions.
5 changes: 5 additions & 0 deletions python/deltalake/_internal.pyi
@@ -221,6 +221,11 @@ class RawDeltaTable:
        starting_timestamp: Optional[str] = None,
        ending_timestamp: Optional[str] = None,
    ) -> pyarrow.RecordBatchReader: ...
    def datafusion_read(
        self,
        predicate: Optional[str] = None,
        columns: Optional[List[str]] = None,
    ) -> List[pyarrow.RecordBatch]: ...

def rust_core_version() -> str: ...
def write_new_deltalake(
7 changes: 7 additions & 0 deletions python/deltalake/table.py
@@ -1417,6 +1417,13 @@ def repair(
        )
        return json.loads(metrics)

    def datafusion_read(
        self,
        predicate: Optional[str] = None,
        columns: Optional[List[str]] = None,
    ) -> List[pyarrow.RecordBatch]:
        """Read the table with DataFusion, optionally filtering rows with a
        SQL predicate and projecting a subset of columns."""
        return self._table.datafusion_read(predicate, columns)


class TableMerger:
"""API for various table `MERGE` commands."""
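A minimal usage sketch of the new method (not part of the commit), assuming a local checkout with the partitioned test table used in the tests below:

import pyarrow as pa
from deltalake import DeltaTable

dt = DeltaTable("../crates/test/tests/data/delta-0.8.0-partitioned")

# datafusion_read returns a list of pyarrow.RecordBatch, which can be
# assembled into a single Table.
table = pa.Table.from_batches(dt.datafusion_read())
print(table.num_rows)  # 7 rows in this test table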
70 changes: 68 additions & 2 deletions python/src/lib.rs
@@ -8,6 +8,7 @@ mod utils;
use std::collections::{HashMap, HashSet};
use std::future::IntoFuture;
use std::str::FromStr;
use std::sync::Arc;
use std::time;
use std::time::{SystemTime, UNIX_EPOCH};

@@ -17,12 +18,18 @@ use delta_kernel::expressions::Scalar;
use delta_kernel::schema::StructField;
use deltalake::arrow::compute::concat_batches;
use deltalake::arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
use deltalake::arrow::pyarrow::ToPyArrow;
use deltalake::arrow::record_batch::{RecordBatch, RecordBatchIterator};
use deltalake::arrow::{self, datatypes::Schema as ArrowSchema};
use deltalake::checkpoints::{cleanup_metadata, create_checkpoint};
use deltalake::datafusion::datasource::provider_as_source;
use deltalake::datafusion::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE};
use deltalake::datafusion::physical_plan::ExecutionPlan;
-use deltalake::datafusion::prelude::SessionContext;
-use deltalake::delta_datafusion::DeltaDataChecker;
+use deltalake::datafusion::prelude::{DataFrame, SessionContext};
+use deltalake::delta_datafusion::{
+    DataFusionMixins, DeltaDataChecker, DeltaScanConfigBuilder, DeltaSessionConfig,
+    DeltaTableProvider,
+};
use deltalake::errors::DeltaTableError;
use deltalake::kernel::{
    scalars::ScalarExt, Action, Add, Invariant, LogicalFile, Remove, StructType,
@@ -1232,6 +1239,65 @@ impl RawDeltaTable {
        self._table.state = table.state;
        Ok(serde_json::to_string(&metrics).unwrap())
    }

    #[pyo3(signature = (predicate = None, columns = None))]
    pub fn datafusion_read(
        &self,
        py: Python,
        predicate: Option<String>,
        columns: Option<Vec<String>>,
    ) -> PyResult<PyObject> {
        // Run the scan outside the GIL; only the final conversion to
        // pyarrow objects needs Python.
        let batches = py.allow_threads(|| -> PyResult<_> {
            let snapshot = self._table.snapshot().map_err(PythonError::from)?;
            let log_store = self._table.log_store();

            // Scan the current snapshot. Parquet pushdown is disabled, so the
            // predicate is evaluated by DataFusion rather than the parquet reader.
            let scan_config = DeltaScanConfigBuilder::default()
                .with_parquet_pushdown(false)
                .build(snapshot)
                .map_err(PythonError::from)?;

            let provider = Arc::new(
                DeltaTableProvider::try_new(snapshot.clone(), log_store, scan_config)
                    .map_err(PythonError::from)?,
            );
            let source = provider_as_source(provider);

            let config = DeltaSessionConfig::default().into();
            let session = SessionContext::new_with_config(config);
            let state = session.state();

            // Parse the optional predicate string into a DataFusion expression
            // against the table's schema.
            let maybe_filter = predicate
                .map(|predicate| snapshot.parse_predicate_expression(predicate, &state))
                .transpose()
                .map_err(PythonError::from)?;

            let filters = match &maybe_filter {
                Some(filter) => vec![filter.clone()],
                None => vec![],
            };

            // Hand the filter to the scan, then re-apply it on the DataFrame
            // together with the optional projection.
            let plan = LogicalPlanBuilder::scan_with_filters(UNNAMED_TABLE, source, None, filters)
                .unwrap()
                .build()
                .unwrap();

            let mut df = DataFrame::new(state, plan);

            if let Some(filter) = maybe_filter {
                df = df.filter(filter).unwrap();
            }

            if let Some(columns) = columns {
                df = df
                    .select_columns(&columns.iter().map(String::as_str).collect::<Vec<_>>())
                    .unwrap();
            }

            // Execute the plan on the shared tokio runtime and collect all
            // record batches eagerly.
            Ok(rt().block_on(async { df.collect().await }).unwrap())
        })?;

        batches.to_pyarrow(py)
    }
}

fn set_post_commithook_properties(
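Because the parsed predicate is handed to the scan and re-applied on the DataFrame before the projection, the two keyword arguments compose; a quick sketch against the same test table (expected values derived from the tests below):

import pyarrow as pa
from deltalake import DeltaTable

dt = DeltaTable("../crates/test/tests/data/delta-0.8.0-partitioned")

# Filter on a partition column, then project two columns.
batches = dt.datafusion_read(predicate="year = '2020'", columns=["value", "month"])
table = pa.Table.from_batches(batches).sort_by("value")
assert table.to_pydict() == {"value": ["1", "2", "3"], "month": ["1", "2", "2"]}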
64 changes: 64 additions & 0 deletions python/tests/test_table_read.py
@@ -946,3 +946,67 @@ def test_is_deltatable_with_storage_opts():
"DELTA_DYNAMO_TABLE_NAME": "custom_table_name",
}
assert DeltaTable.is_deltatable(table_path, storage_options=storage_options)


def test_datafusion_read_table():
table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
dt = DeltaTable(table_path)
expected = {
"value": ["1", "2", "3", "4", "5", "6", "7"],
"year": ["2020", "2020", "2020", "2021", "2021", "2021", "2021"],
"month": ["1", "2", "2", "4", "12", "12", "12"],
"day": ["1", "3", "5", "5", "4", "20", "20"],
}
actual = pa.Table.from_batches(dt.datafusion_read()).sort_by("value").to_pydict()
assert expected == actual


def test_datafusion_read_table_with_columns():
table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
dt = DeltaTable(table_path)
expected = {
"value": ["1", "2", "3", "4", "5", "6", "7"],
"day": ["1", "3", "5", "5", "4", "20", "20"],
}
actual = (
pa.Table.from_batches(dt.datafusion_read(columns=["value", "day"]))
.sort_by("value")
.to_pydict()
)
assert expected == actual


def test_datafusion_read_with_filter_on_partitioned_column():
table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
dt = DeltaTable(table_path)
expected = {
"value": ["1", "2", "3"],
"year": ["2020", "2020", "2020"],
"month": ["1", "2", "2"],
"day": ["1", "3", "5"],
}
actual = (
pa.Table.from_batches(dt.datafusion_read(predicate="year = '2020'"))
.sort_by("value")
.to_pydict()
)
assert expected == actual


def test_datafusion_read_with_filter_on_multiple_columns():
table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
dt = DeltaTable(table_path)
expected = {
"value": ["4", "5"],
"year": ["2021", "2021"],
"month": ["4", "12"],
"day": ["5", "4"],
}
actual = (
pa.Table.from_batches(
dt.datafusion_read(predicate="year = '2021' and value < '6'")
)
.sort_by("value")
.to_pydict()
)
assert expected == actual
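The tests cover predicate and columns separately; a possible additional test for the combined case (expected values derived from the first test's data, not part of the commit):

def test_datafusion_read_with_filter_and_columns():
    table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
    dt = DeltaTable(table_path)
    expected = {
        "value": ["4", "5", "6", "7"],
        "day": ["5", "4", "20", "20"],
    }
    actual = (
        pa.Table.from_batches(
            dt.datafusion_read(predicate="year = '2021'", columns=["value", "day"])
        )
        .sort_by("value")
        .to_pydict()
    )
    assert expected == actual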
