From aa3a0c02fc8cf3f82abfaf0ea03af8159d849c15 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Fri, 17 Nov 2023 08:47:45 +0100 Subject: [PATCH] use parametrized constructor to reduce tests --- python/tests/test_writer.py | 89 +++++++++---------------------------- 1 file changed, 21 insertions(+), 68 deletions(-) diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 137b35e8f3..4330489e4a 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -304,15 +304,21 @@ def test_write_iterator( assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data -@pytest.parametrize("large_dtypes", [True, False]) -@pytest.parametrize("constructor", [ - lambda table: table.to_pyarrow_dataset(), - lambda table: table.to_pyarrow_table(), - lambda table: table.to_pyarrow_table().to_batches()[0] -]) -def test_write_dataset( - tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table, - large_dtypes: bool, constructor +@pytest.mark.parametrize("large_dtypes", [True, False]) +@pytest.mark.parametrize( + "constructor", + [ + lambda table: table.to_pyarrow_dataset(), + lambda table: table.to_pyarrow_table(), + lambda table: table.to_pyarrow_table().to_batches()[0], + ], +) +def test_write_dataset_table_recordbatch( + tmp_path: pathlib.Path, + existing_table: DeltaTable, + sample_data: pa.Table, + large_dtypes: bool, + constructor, ): dataset = constructor(existing_table) @@ -320,72 +326,19 @@ def test_write_dataset( assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data -def test_write_dataset_large_types( - tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table -): - dataset = existing_table.to_pyarrow_dataset() - - write_deltalake(tmp_path, dataset, mode="overwrite", large_dtypes=True) - assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data - - -def test_write_table( - tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table -): - dataset = existing_table.to_pyarrow_table() - - write_deltalake(tmp_path, dataset, mode="overwrite") - assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data - - -def test_write_table_large_dtypes( - tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table -): - dataset = existing_table.to_pyarrow_table() - - write_deltalake(tmp_path, dataset, mode="overwrite", large_dtypes=True) - assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data - - -def test_write_recordbatch( - tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table -): - batch = existing_table.to_pyarrow_table().to_batches() - - write_deltalake(tmp_path, batch[0], mode="overwrite") - assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data - - -def test_write_recordbatch_large_dtypes( - tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table -): - batch = existing_table.to_pyarrow_table().to_batches() - - write_deltalake(tmp_path, batch[0], mode="overwrite", large_dtypes=True) - assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data - - +@pytest.mark.parametrize("large_dtypes", [True, False]) def test_write_recordbatchreader( - tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table -): - batches = existing_table.to_pyarrow_dataset().to_batches() - reader = RecordBatchReader.from_batches( - existing_table.to_pyarrow_dataset().schema, batches - ) - - write_deltalake(tmp_path, reader, mode="overwrite") - assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data - - -def test_write_recordbatchreader_large_dtypes( - tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table + tmp_path: pathlib.Path, + existing_table: DeltaTable, + sample_data: pa.Table, + large_dtypes: bool, ): batches = existing_table.to_pyarrow_dataset().to_batches() reader = RecordBatchReader.from_batches( existing_table.to_pyarrow_dataset().schema, batches ) - write_deltalake(tmp_path, reader, mode="overwrite", large_dtypes=True) + write_deltalake(tmp_path, reader, mode="overwrite", large_dtypes=large_dtypes) assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data