
Add test for more write inputs with large types
ion-elgreco committed Nov 11, 2023
Parent: bc46b27 · Commit: b669358
Showing 2 changed files with 42 additions and 2 deletions.
python/deltalake/schema.py (2 additions, 2 deletions)

@@ -111,7 +111,7 @@ def convert_pyarrow_table(data: pa.Table, large_dtypes: bool) -> pa.RecordBatchR
 def convert_pyarrow_dataset(
     data: ds.Dataset, large_dtypes: bool
 ) -> pa.RecordBatchReader:
-    """Converts a PyArrow table to a PyArrow RecordBatchReader with a compatible delta schema"""
+    """Converts a PyArrow dataset to a PyArrow RecordBatchReader, schema is kept aside and used during write"""
     schema = _convert_pa_schema_to_delta(data.schema, large_dtypes=large_dtypes)
-    data = data.replace_schema(schema).scanner().to_reader()
+    data = data.scanner().to_reader()
     return data, schema
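
For context, a minimal sketch of how the changed function might be consumed; the in-memory dataset and variable names are illustrative, not part of the commit, and assume the tuple return (reader, schema) shown above:

import pyarrow as pa
import pyarrow.dataset as ds
from deltalake.schema import convert_pyarrow_dataset

# Any ds.Dataset works; an in-memory one keeps the sketch self-contained.
dataset = ds.dataset(pa.table({"name": ["a", "b"]}))

# After this change the reader streams the dataset with its original schema;
# the delta-compatible (large-type) schema is returned separately, to be
# applied later, during the actual write.
reader, delta_schema = convert_pyarrow_dataset(dataset, large_dtypes=True)
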
python/tests/test_writer.py (40 additions, 0 deletions)

@@ -313,6 +313,15 @@ def test_write_dataset(
     assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data
 
 
+def test_write_dataset_large_types(
+    tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table
+):
+    dataset = existing_table.to_pyarrow_dataset()
+
+    write_deltalake(tmp_path, dataset, mode="overwrite", large_dtypes=True)
+    assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data
+
+
 def test_write_table(
     tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table
 ):
@@ -322,6 +331,15 @@ def test_write_table(
     assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data
 
 
+def test_write_table_large_dtypes(
+    tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table
+):
+    dataset = existing_table.to_pyarrow_table()
+
+    write_deltalake(tmp_path, dataset, mode="overwrite", large_dtypes=True)
+    assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data
+
+
 def test_write_recordbatch(
     tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table
 ):
@@ -332,6 +350,16 @@ def test_write_recordbatch(
     assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data
 
 
+def test_write_recordbatch_large_dtypes(
+    tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table
+):
+    batch = existing_table.to_pyarrow_table().to_batches()
+    print(len(batch))
+
+    write_deltalake(tmp_path, batch[0], mode="overwrite", large_dtypes=True)
+    assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data
+
+
 def test_write_recordbatchreader(
     tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table
 ):
@@ -344,6 +372,18 @@ def test_write_recordbatchreader(
     assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data
 
 
+def test_write_recordbatchreader_large_dtypes(
+    tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table
+):
+    batches = existing_table.to_pyarrow_dataset().to_batches()
+    reader = RecordBatchReader.from_batches(
+        existing_table.to_pyarrow_dataset().schema, batches
+    )
+
+    write_deltalake(tmp_path, reader, mode="overwrite", large_dtypes=True)
+    assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data
+
+
 def test_writer_partitioning(tmp_path: pathlib.Path):
     test_strings = ["a=b", "hello world", "hello%20world"]
     data = pa.table(
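
As a rough illustration of what the large_dtypes flag in these tests drives, here is a sketch of the underlying schema conversion; the exact output types are an assumption about _convert_pa_schema_to_delta's behavior, inferred from the flag's purpose rather than stated in this commit:

import pyarrow as pa
from deltalake.schema import _convert_pa_schema_to_delta

# A schema using the regular 32-bit-offset string and list types.
src = pa.schema([("name", pa.string()), ("scores", pa.list_(pa.int64()))])

# With large_dtypes=True the delta-compatible schema is presumed to use the
# 64-bit-offset variants (large_string, large_list); with False the narrow
# types should pass through unchanged.
print(_convert_pa_schema_to_delta(src, large_dtypes=True))

Each new test then feeds a different input shape (dataset, table, record batch, record batch reader) through write_deltalake with this flag and checks the round trip against sample_data.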
