Skip to content

Commit

Permalink
rust bindings for cdf
Browse files Browse the repository at this point in the history
minor rust fixes

minor fix
  • Loading branch information
PatrickJin-db committed Dec 11, 2024
1 parent 49b188b commit 5ce85e1
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 34 deletions.
1 change: 1 addition & 0 deletions python/delta-kernel-rust-sharing-wrapper/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cargo.lock
2 changes: 1 addition & 1 deletion python/delta-kernel-rust-sharing-wrapper/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ crate-type = ["cdylib"]

[dependencies]
arrow = { version = "53.3.0", features = ["pyarrow"] }
delta_kernel = {version = "0.5", features = ["cloud", "default", "default-engine"]}
delta_kernel = { git = "https://github.com/OussamaSaoudi-db/delta-kernel-rs.git", rev = "a25f508", features = ["cloud", "default", "default-engine"]}
openssl = { version = "0.10", features = ["vendored"] }
url = "2"

Expand Down
125 changes: 92 additions & 33 deletions python/delta-kernel-rust-sharing-wrapper/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
use std::sync::Arc;

use arrow::compute::filter_record_batch;
use arrow::datatypes::SchemaRef;
use arrow::datatypes::SchemaRef as ArrowSchemaRef;
use arrow::error::ArrowError;
use arrow::pyarrow::PyArrowType;
use delta_kernel::engine::arrow_data::ArrowEngineData;
use arrow::record_batch::{RecordBatch, RecordBatchIterator, RecordBatchReader};

use delta_kernel::engine::default::executor::tokio::TokioBackgroundExecutor;
use delta_kernel::engine::default::DefaultEngine;
use delta_kernel::scan::ScanResult;
use delta_kernel::{engine::arrow_data::ArrowEngineData, schema::StructType};
use delta_kernel::{DeltaResult, Engine};

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

use url::Url;

use arrow::record_batch::{RecordBatch, RecordBatchIterator, RecordBatchReader};
use delta_kernel::Engine;

use std::collections::HashMap;

struct KernelError(delta_kernel::Error);
Expand Down Expand Up @@ -78,6 +80,38 @@ impl ScanBuilder {
}
}

fn try_get_schema(schema: &Arc<StructType>) -> Result<ArrowSchemaRef, delta_kernel::Error> {
Ok(Arc::new(schema.as_ref().try_into().map_err(|e| {
delta_kernel::Error::Generic(format!("Could not get result schema: {e}"))
})?))
}

// TODO(patrick.jin): Change return type to RecordBatchIterator<impl Iterator<...>>
/// Turns kernel scan results into a `RecordBatchIterator` over Arrow batches.
///
/// Every scan result is converted up front and the outcomes (successes and
/// failures alike) are collected into a `Vec`, so `results` is fully consumed
/// before this function returns. Individual failures are not raised here; they
/// are stored as `Err` items of the returned iterator.
fn try_create_record_batch_iter(
    results: impl Iterator<Item = DeltaResult<ScanResult>>,
    result_schema: ArrowSchemaRef,
) -> RecordBatchIterator<Vec<Result<RecordBatch, ArrowError>>> {
    /// Converts one scan result into a record batch, applying its selection
    /// mask when one is present.
    fn to_record_batch(res: DeltaResult<ScanResult>) -> Result<RecordBatch, ArrowError> {
        // Read the mask before moving `raw_data` out of the scan result.
        let (selection, engine_data) = res
            .and_then(|scan_result| Ok((scan_result.full_mask(), scan_result.raw_data?)))
            .map_err(|e| ArrowError::from_external_error(Box::new(e)))?;

        let batch: RecordBatch = engine_data
            .into_any()
            .downcast::<ArrowEngineData>()
            .map_err(|_| ArrowError::CastError("Couldn't cast to ArrowEngineData".to_string()))?
            .into();

        match selection {
            Some(mask) => filter_record_batch(&batch, &mask.into()),
            None => Ok(batch),
        }
    }

    let converted: Vec<_> = results.map(to_record_batch).collect();
    RecordBatchIterator::new(converted, result_schema)
}

// Python-exposed newtype wrapping a kernel `delta_kernel::scan::Scan`.
// (Plain `//` comment on purpose: a `///` doc comment on a `#[pyclass]`
// would be exported as the Python class docstring.)
#[pyclass]
struct Scan(delta_kernel::scan::Scan);

Expand All @@ -87,38 +121,61 @@ impl Scan {
&self,
engine_interface: &PythonInterface,
) -> DeltaPyResult<PyArrowType<Box<dyn RecordBatchReader + Send>>> {
let result_schema: SchemaRef =
Arc::new(self.0.schema().as_ref().try_into().map_err(|e| {
delta_kernel::Error::Generic(format!("Could not get result schema: {e}"))
})?);
let results = self.0.execute(engine_interface.0.as_ref())?;
let record_batches: Vec<_> = results
.map(|res| {
let scan_res = res.and_then(|res| Ok((res.full_mask(), res.raw_data?)));
let (mask, data) =
scan_res.map_err(|e| ArrowError::from_external_error(Box::new(e)))?;
let record_batch: RecordBatch = data
.into_any()
.downcast::<ArrowEngineData>()
.map_err(|_| {
ArrowError::CastError("Couldn't cast to ArrowEngineData".to_string())
})?
.into();
if let Some(mask) = mask {
let filtered_batch = filter_record_batch(&record_batch, &mask.into())?;
Ok(filtered_batch)
} else {
Ok(record_batch)
}
})
.collect();
let record_batch_iter = RecordBatchIterator::new(record_batches, result_schema);
let result_schema: ArrowSchemaRef = try_get_schema(self.0.schema())?;
let results = self.0.execute(engine_interface.0.clone())?;
let record_batch_iter = try_create_record_batch_iter(results, result_schema);
Ok(PyArrowType(Box::new(record_batch_iter)))
}
}

// Python-exposed wrapper around the kernel's table-changes scan builder.
// The inner builder is held in an `Option` so that `build` can move it out
// of `&mut self` with `take()`, leaving `None` behind.
#[pyclass]
struct TableChangesScanBuilder(
Option<delta_kernel::table_changes::scan::TableChangesScanBuilder>
);

#[pymethods]
impl TableChangesScanBuilder {
#[new]
#[pyo3(signature = (table, engine_interface, start_version, end_version=None))]
fn new(
table: &Table,
engine_interface: &PythonInterface,
start_version: u64,
end_version: Option<u64>,
) -> DeltaPyResult<TableChangesScanBuilder> {
let table_changes =
table
.0
.table_changes(engine_interface.0.as_ref(), start_version, end_version)?;
Ok(TableChangesScanBuilder(Some(
table_changes.into_scan_builder(),
)))
}

fn build(&mut self) -> DeltaPyResult<TableChangesScan> {
let scan = self.0.take().unwrap().build()?;
Ok(TableChangesScan(scan))
}
}

// Python-exposed newtype wrapping a kernel
// `delta_kernel::table_changes::scan::TableChangesScan`.
#[pyclass]
struct TableChangesScan(delta_kernel::table_changes::scan::TableChangesScan);

#[pymethods]
impl TableChangesScan {
    // Runs the table-changes scan with the given engine and hands the
    // resulting record batches back to Python as an Arrow record batch reader.
    // Schema conversion and scan execution errors are propagated to the caller.
    fn execute(
        &self,
        engine_interface: &PythonInterface,
    ) -> DeltaPyResult<PyArrowType<Box<dyn RecordBatchReader + Send>>> {
        // Resolve the Arrow schema first so a conversion failure surfaces
        // before the scan is started.
        let arrow_schema = try_get_schema(self.0.schema())?;
        let reader = self
            .0
            .execute(engine_interface.0.clone())
            .map(|scan_results| try_create_record_batch_iter(scan_results, arrow_schema))?;
        Ok(PyArrowType(Box::new(reader)))
    }
}

#[pyclass]
struct PythonInterface(Box<dyn Engine + Send>);
struct PythonInterface(Arc<dyn Engine + Send>);

#[pymethods]
impl PythonInterface {
Expand All @@ -130,7 +187,7 @@ impl PythonInterface {
HashMap::<String, String>::new(),
Arc::new(TokioBackgroundExecutor::new()),
)?;
Ok(PythonInterface(Box::new(client)))
Ok(PythonInterface(Arc::new(client)))
}
}

Expand All @@ -144,5 +201,7 @@ fn delta_kernel_rust_sharing_wrapper(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Snapshot>()?;
m.add_class::<ScanBuilder>()?;
m.add_class::<Scan>()?;
m.add_class::<TableChangesScanBuilder>()?;
m.add_class::<TableChangesScan>()?;
Ok(())
}

0 comments on commit 5ce85e1

Please sign in to comment.