
Commit

rename and clean up
ion-elgreco committed Nov 17, 2023
1 parent 4f8f359 commit 284f8c7
Showing 4 changed files with 65 additions and 59 deletions.
6 changes: 3 additions & 3 deletions python/deltalake/_internal.pyi
@@ -147,9 +147,9 @@ def write_to_deltalake(
     mode: str,
     max_rows_per_group: int,
     overwrite_schema: bool,
-    name: Optional[str],
-    description: Optional[str],
-    configuration: Optional[Mapping[str, Optional[str]]],
+    _name: Optional[str],
+    _description: Optional[str],
+    _configuration: Optional[Mapping[str, Optional[str]]],
     storage_options: Optional[Dict[str, str]],
 ) -> None: ...
 def batch_distinct(batch: pyarrow.RecordBatch) -> pyarrow.RecordBatch: ...
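The leading underscores mark arguments the rust engine now accepts but ignores; the metadata plumbing was removed from the Rust side (see the lib.rs changes below) until proper support lands. A minimal sketch of the user-visible effect, assuming the deltalake API at this commit (path and values are placeholders):

```python
import pyarrow as pa
from deltalake import write_deltalake

data = pa.table({"x": pa.array([1, 2, 3])})

# write_deltalake still accepts the metadata keywords, but when routed to the
# rust engine they arrive as _name/_description/_configuration and are
# currently no-ops.
write_deltalake(
    "/tmp/rust_demo",  # hypothetical path
    data,
    engine="rust",
    name="demo",               # ignored by the rust engine for now
    description="demo table",  # ignored by the rust engine for now
)
```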
24 changes: 16 additions & 8 deletions python/deltalake/writer.py
@@ -38,8 +38,8 @@

 from ._internal import DeltaDataChecker as _DeltaDataChecker
 from ._internal import batch_distinct
-from ._internal import write_new_deltalake as _write_new_deltalake
-from ._internal import write_to_deltalake as _write_to_deltalake
+from ._internal import write_new_deltalake as write_deltalake_pyarrow
+from ._internal import write_to_deltalake as write_deltalake_rust
 from .exceptions import DeltaProtocolError, TableNotFoundError
 from .table import MAX_SUPPORTED_WRITER_VERSION, DeltaTable

@@ -179,6 +179,11 @@ def write_deltalake(
     if table is not None and mode == "ignore":
         return

+    if mode == "overwrite" and overwrite_schema:
+        raise NotImplementedError(
+            "The rust engine writer does not yet support schema evolution."
+        )
+
     if isinstance(data, RecordBatchReader):
         batch_iter = data
     elif isinstance(data, pa.RecordBatch):
@@ -188,7 +193,10 @@
     elif isinstance(data, ds.Dataset):
         batch_iter = data.scanner().to_reader()
     elif isinstance(data, pd.DataFrame):
-        batch_iter = pa.Table.from_pandas(data).to_reader()
+        if schema is not None:
+            batch_iter = pa.Table.from_pandas(data, schema).to_reader()
+        else:
+            batch_iter = pa.Table.from_pandas(data).to_reader()
     else:
         batch_iter = data

@@ -199,16 +207,16 @@
             raise ValueError("You must provide schema if data is Iterable")

         data = RecordBatchReader.from_batches(schema, (batch for batch in batch_iter))
-    _write_to_deltalake(
+    write_deltalake_rust(
         table_uri=table_uri,
         data=data,
         partition_by=partition_by,
         mode=mode,
         max_rows_per_group=max_rows_per_group,
         overwrite_schema=overwrite_schema,
-        name=name,
-        description=description,
-        configuration=configuration,
+        _name=name,
+        _description=description,
+        _configuration=configuration,
         storage_options=storage_options,
     )
     if table:
@@ -412,7 +420,7 @@ def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch:
         )

     if table is None:
-        _write_new_deltalake(
+        write_deltalake_pyarrow(
             table_uri,
             schema,
             add_actions,
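Two behavior changes land in writer.py: the rust engine now fails fast when asked to evolve the schema on overwrite, and an explicit schema is honored when converting a pandas DataFrame. A hedged sketch of both, assuming the API at this commit (path and data are placeholders):

```python
import pandas as pd
import pyarrow as pa
from deltalake import write_deltalake

df = pd.DataFrame({"x": [1, 2, 3]})
schema = pa.schema([pa.field("x", pa.int32())])

# The schema argument is now threaded into pa.Table.from_pandas, so the
# default int64 pandas column is cast to int32 before the write.
write_deltalake("/tmp/writer_demo", df, schema=schema, engine="rust")

# Overwrite plus overwrite_schema is guarded until the rust writer supports
# schema evolution.
try:
    write_deltalake(
        "/tmp/writer_demo",
        df,
        mode="overwrite",
        overwrite_schema=True,
        engine="rust",
    )
except NotImplementedError as err:
    print(err)  # The rust engine writer does not yet support schema evolution.
```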
31 changes: 3 additions & 28 deletions python/src/lib.rs
@@ -43,7 +43,6 @@ use deltalake::DeltaTableBuilder;
 use pyo3::exceptions::{PyIOError, PyRuntimeError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::types::{PyFrozenSet, PyType};
-use serde_json::Value;

 use crate::error::DeltaProtocolError;
 use crate::error::PythonError;
@@ -1141,40 +1140,18 @@ impl From<&PyAddAction> for Add {
 fn write_to_deltalake(
     table_uri: String,
     data: PyArrowType<ArrowArrayStreamReader>,
-    // schema: Option<PyArrowType<ArrowSchema>>, // maybe do the schema casting on python side
     mode: String,
     max_rows_per_group: i64,
     overwrite_schema: bool,
     partition_by: Option<Vec<String>>,
-    name: Option<String>,
-    description: Option<String>,
-    configuration: Option<HashMap<String, Option<String>>>,
+    _name: Option<String>,
+    _description: Option<String>,
+    _configuration: Option<HashMap<String, Option<String>>>,
     storage_options: Option<HashMap<String, String>>,
 ) -> PyResult<()> {
     let batches = data.0.map(|batch| batch.unwrap()).collect::<Vec<_>>();
     let save_mode = save_mode_from_str(&mode)?;

-    let mut metadata: HashMap<String, Value> = HashMap::new();
-
-    if let Some(name) = name {
-        metadata.insert("name".to_string(), name.into());
-    }
-
-    if let Some(description) = description {
-        metadata.insert("description".to_string(), description.into());
-    }
-
-    if let Some(configuration) = configuration {
-        metadata.insert("configuration".to_string(), json!(configuration));
-    }
-
-    // // This should be done when the table can not be loaded ...
-    // match save_mode {
-    //     SaveMode::Ignore => {
-    //         return Ok(())
-    //     }
-    //     _ => ()
-    // }
     let options = storage_options.clone().unwrap_or_default();
     let table = rt()?
         .block_on(DeltaOps::try_from_uri_with_storage_options(
@@ -1186,7 +1163,6 @@ fn write_to_deltalake(
         .write(batches)
         .with_save_mode(save_mode)
         .with_overwrite_schema(overwrite_schema)
-        .with_metadata(metadata)
         .with_write_batch_size(max_rows_per_group as usize);

     if let Some(partition_columns) = partition_by {
@@ -1200,7 +1176,6 @@
     Ok(())
 }

-use serde_json::json;
 #[pyfunction]
 #[allow(clippy::too_many_arguments)]
 fn write_new_deltalake(
63 changes: 43 additions & 20 deletions python/tests/test_writer.py
@@ -145,25 +145,29 @@ def test_update_schema(existing_table: DeltaTable):
     assert existing_table.schema().to_pyarrow() == new_data.schema


-# def test_update_schema_rust_writer(existing_table: DeltaTable): # Test fails
-#     new_data = pa.table({"x": pa.array([1, 2, 3])})
-
-#     with pytest.raises(DeltaError):
-#         write_deltalake(
-#             existing_table,
-#             new_data,
-#             mode="append",
-#             overwrite_schema=True,
-#             engine="rust",
-#         )
-
-#     write_deltalake(
-#         existing_table, new_data, mode="overwrite", overwrite_schema=True, engine="rust"
-#     )
-
-#     read_data = existing_table.to_pyarrow_table()
-#     assert new_data == read_data
-#     assert existing_table.schema().to_pyarrow() == new_data.schema
+def test_update_schema_rust_writer(existing_table: DeltaTable):
+    new_data = pa.table({"x": pa.array([1, 2, 3])})
+
+    with pytest.raises(DeltaError):
+        write_deltalake(
+            existing_table,
+            new_data,
+            mode="append",
+            overwrite_schema=True,
+            engine="rust",
+        )
+    with pytest.raises(NotImplementedError):
+        write_deltalake(
+            existing_table,
+            new_data,
+            mode="overwrite",
+            overwrite_schema=True,
+            engine="rust",
+        )
+
+    read_data = existing_table.to_pyarrow_table()
+    assert new_data == read_data
+    assert existing_table.schema().to_pyarrow() == new_data.schema


 @pytest.mark.parametrize("engine", ["pyarrow", "rust"])
@@ -185,15 +189,34 @@ def test_local_path(
     assert table == sample_data


-@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
-def test_roundtrip_metadata(tmp_path: pathlib.Path, sample_data: pa.Table, engine):
+@pytest.mark.skip(reason="Waiting on support for creating metadata during write")
+def test_roundtrip_metadata_rust(tmp_path: pathlib.Path, sample_data: pa.Table):
     write_deltalake(
         tmp_path,
         sample_data,
         name="test_name",
         description="test_desc",
         configuration={"delta.appendOnly": "false", "foo": "bar"},
-        engine=engine,
+        engine="rust",
     )

     delta_table = DeltaTable(tmp_path)

+    metadata = delta_table.metadata()
+
+    assert metadata.name == "test_name"
+    assert metadata.description == "test_desc"
+    assert metadata.configuration == {"delta.appendOnly": "false", "foo": "bar"}
+
+
+def test_roundtrip_metadata(tmp_path: pathlib.Path, sample_data: pa.Table):
+    write_deltalake(
+        tmp_path,
+        sample_data,
+        name="test_name",
+        description="test_desc",
+        configuration={"delta.appendOnly": "false", "foo": "bar"},
+        engine="pyarrow",
+    )
+
+    delta_table = DeltaTable(tmp_path)
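Outside pytest, the metadata round-trip that the kept pyarrow-engine test exercises looks like this (a sketch; the path is a placeholder):

```python
import pyarrow as pa
from deltalake import DeltaTable, write_deltalake

write_deltalake(
    "/tmp/meta_demo",  # hypothetical path
    pa.table({"x": [1, 2, 3]}),
    name="test_name",
    description="test_desc",
    configuration={"delta.appendOnly": "false", "foo": "bar"},
    engine="pyarrow",
)

metadata = DeltaTable("/tmp/meta_demo").metadata()
assert metadata.name == "test_name"
assert metadata.description == "test_desc"
assert metadata.configuration == {"delta.appendOnly": "false", "foo": "bar"}
```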
