diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index b632176c5a..b952b39ed2 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -357,11 +357,16 @@ def writer_spec(cls) -> FileWriterSpec: class CsvWriter(DataWriter): def __init__( - self, f: IO[Any], caps: DestinationCapabilitiesContext = None, delimiter: str = "," + self, + f: IO[Any], + caps: DestinationCapabilitiesContext = None, + delimiter: str = ",", + bytes_encoding: str = "utf-8", ) -> None: super().__init__(f, caps) self.delimiter = delimiter self.writer: csv.DictWriter[str] = None + self.bytes_encoding = bytes_encoding def write_header(self, columns_schema: TTableSchemaColumns) -> None: self._columns_schema = columns_schema @@ -374,8 +379,37 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: quoting=csv.QUOTE_NONNUMERIC, ) self.writer.writeheader() + # find row items that are of the complex type (could be abstracted out for use in other writers?) + self.complex_indices = [ + i for i, field in columns_schema.items() if field["data_type"] == "complex" + ] + # find row items that are of the complex type (could be abstracted out for use in other writers?) + self.bytes_indices = [ + i for i, field in columns_schema.items() if field["data_type"] == "binary" + ] def write_data(self, rows: Sequence[Any]) -> None: + # convert bytes and json + if self.complex_indices or self.bytes_indices: + for row in rows: + for key in self.complex_indices: + if (value := row.get(key)) is not None: + row[key] = json.dumps(value) + for key in self.bytes_indices: + if (value := row.get(key)) is not None: + # assumed bytes value + try: + row[key] = value.decode(self.bytes_encoding) + except UnicodeError: + raise InvalidDataItem( + "csv", + "object", + f"'{key}' contains bytes that cannot be decoded with" + f" {self.bytes_encoding}. Remove binary columns or replace their" + " content with a hex representation: \\x... while keeping data" + " type as binary.", + ) + self.writer.writerows(rows) # count rows that got written self.items_count += sum(len(row) for row in rows) diff --git a/docs/website/docs/dlt-ecosystem/file-formats/csv.md b/docs/website/docs/dlt-ecosystem/file-formats/csv.md index d72ef982af..4eb94b5ff0 100644 --- a/docs/website/docs/dlt-ecosystem/file-formats/csv.md +++ b/docs/website/docs/dlt-ecosystem/file-formats/csv.md @@ -33,6 +33,12 @@ info = pipeline.run(some_source(), loader_file_format="csv") * dates are represented as ISO 8601 ## Limitations +**arrow writer** * binary columns are supported only if they contain valid UTF-8 characters * complex (nested, struct) types are not supported + +**csv writer** +* binary columns are supported only if they contain valid UTF-8 characters (easy to add more encodings) +* complex columns dumped with json.dumps +* **None** values are always quoted \ No newline at end of file diff --git a/tests/cases.py b/tests/cases.py index 8885df0c1b..edff3f5c2c 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -19,9 +19,7 @@ ) from dlt.common.schema import TColumnSchema, TTableSchemaColumns - -TArrowFormat = Literal["pandas", "table", "record_batch"] - +from tests.utils import TArrowFormat, TestDataItemFormat, arrow_item_from_pandas # _UUID = "c8209ee7-ee95-4b90-8c9f-f7a0f8b51014" JSON_TYPED_DICT: StrAny = { @@ -276,38 +274,8 @@ def assert_all_data_types_row( assert db_mapping == expected_rows -def arrow_format_from_pandas( - df: Any, - object_format: TArrowFormat, -) -> Any: - from dlt.common.libs.pyarrow import pyarrow as pa - - if object_format == "pandas": - return df - elif object_format == "table": - return pa.Table.from_pandas(df) - elif object_format == "record_batch": - return pa.RecordBatch.from_pandas(df) - raise ValueError("Unknown item type: " + object_format) - - -def arrow_item_from_table( - table: Any, - object_format: TArrowFormat, -) -> Any: - from dlt.common.libs.pyarrow import pyarrow as pa - - if object_format == "pandas": - return table.to_pandas() - elif object_format == "table": - return table - elif object_format == "record_batch": - return table.to_batches()[0] - raise ValueError("Unknown item type: " + object_format) - - def arrow_table_all_data_types( - object_format: TArrowFormat, + object_format: TestDataItemFormat, include_json: bool = True, include_time: bool = True, include_binary: bool = True, @@ -374,7 +342,10 @@ def arrow_table_all_data_types( .drop(columns=["null"]) .to_dict("records") ) - return arrow_format_from_pandas(df, object_format), rows, data + if object_format == "object": + return rows, rows, data + else: + return arrow_item_from_pandas(df, object_format), rows, data def prepare_shuffled_tables() -> Tuple[Any, Any, Any]: @@ -382,7 +353,7 @@ def prepare_shuffled_tables() -> Tuple[Any, Any, Any]: from dlt.common.libs.pyarrow import pyarrow as pa table, _, _ = arrow_table_all_data_types( - "table", + "arrow-table", include_json=False, include_not_normalized_name=False, tz="Europe/Berlin", diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py index d5775266f2..a1101fddb1 100644 --- a/tests/extract/test_incremental.py +++ b/tests/extract/test_incremental.py @@ -213,7 +213,7 @@ def some_data(created_at=dlt.sources.incremental("data.items[0].created_at")): assert s["last_value"] == 2 -@pytest.mark.parametrize("item_type", ["arrow", "pandas"]) +@pytest.mark.parametrize("item_type", ["arrow-table", "pandas"]) def test_nested_cursor_path_arrow_fails(item_type: TestDataItemFormat) -> None: data = [{"data": {"items": [{"created_at": 2}]}}] source_items = data_to_item_format(item_type, data) @@ -708,7 +708,7 @@ def some_data(step, last_timestamp=dlt.sources.incremental("ts")): p.run(r, destination="duckdb") -@pytest.mark.parametrize("item_type", set(ALL_TEST_DATA_ITEM_FORMATS) - {"json"}) +@pytest.mark.parametrize("item_type", set(ALL_TEST_DATA_ITEM_FORMATS) - {"object"}) def test_start_value_set_to_last_value_arrow(item_type: TestDataItemFormat) -> None: p = dlt.pipeline(pipeline_name=uniq_id(), destination="duckdb") now = pendulum.now() @@ -1047,7 +1047,7 @@ def some_data( resource.apply_hints(incremental=dlt.sources.incremental("updated_at", initial_value=start_dt)) # and the data is naive. so it will work as expected with naive datetimes in the result set data = list(resource) - if item_type == "json": + if item_type == "object": # we do not convert data in arrow tables assert data[0]["updated_at"].tzinfo is None @@ -1059,7 +1059,7 @@ def some_data( ) ) data = list(resource) - if item_type == "json": + if item_type == "object": assert data[0]["updated_at"].tzinfo is None # now use naive initial value but data is UTC @@ -1070,7 +1070,7 @@ def some_data( ) ) # will cause invalid comparison - if item_type == "json": + if item_type == "object": with pytest.raises(InvalidStepFunctionArguments): list(resource) else: @@ -1392,7 +1392,7 @@ def descending( for chunk in chunks(count(start=48, step=-1), 10): data = [{"updated_at": i, "package": package} for i in chunk] # print(data) - yield data_to_item_format("json", data) + yield data_to_item_format("object", data) if updated_at.can_close(): out_of_range.append(package) return diff --git a/tests/extract/utils.py b/tests/extract/utils.py index 5239c38de3..61ccc4d5f4 100644 --- a/tests/extract/utils.py +++ b/tests/extract/utils.py @@ -46,7 +46,7 @@ def expect_extracted_file( class AssertItems(ItemTransform[TDataItem]): - def __init__(self, expected_items: Any, item_type: TestDataItemFormat = "json") -> None: + def __init__(self, expected_items: Any, item_type: TestDataItemFormat = "object") -> None: self.expected_items = expected_items self.item_type = item_type @@ -56,7 +56,7 @@ def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: def data_item_to_list(from_type: TestDataItemFormat, values: List[TDataItem]): - if from_type in ["arrow", "arrow-batch"]: + if from_type in ["arrow-table", "arrow-batch"]: return values[0].to_pylist() elif from_type == "pandas": return values[0].to_dict("records") diff --git a/tests/libs/test_arrow_csv_writer.py b/tests/libs/test_arrow_csv_writer.py deleted file mode 100644 index b9b0555f1d..0000000000 --- a/tests/libs/test_arrow_csv_writer.py +++ /dev/null @@ -1,128 +0,0 @@ -import csv -from copy import copy -import pytest -import pyarrow.parquet as pq - -from dlt.common.data_writers.exceptions import InvalidDataItem -from dlt.common.data_writers.writers import ArrowToCsvWriter, ParquetDataWriter -from dlt.common.libs.pyarrow import remove_columns - -from tests.common.data_writers.utils import get_writer -from tests.cases import ( - TABLE_UPDATE_COLUMNS_SCHEMA, - TABLE_ROW_ALL_DATA_TYPES_DATETIMES, - TABLE_ROW_ALL_DATA_TYPES, - arrow_table_all_data_types, -) - - -def test_csv_writer_all_data_fields() -> None: - data = copy(TABLE_ROW_ALL_DATA_TYPES_DATETIMES) - - # write parquet and read it - with get_writer(ParquetDataWriter) as pq_writer: - pq_writer.write_data_item([data], TABLE_UPDATE_COLUMNS_SCHEMA) - - with open(pq_writer.closed_files[0].file_path, "rb") as f: - table = pq.read_table(f) - - with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: - writer.write_data_item(table, TABLE_UPDATE_COLUMNS_SCHEMA) - - with open(writer.closed_files[0].file_path, "r", encoding="utf-8") as f: - rows = list(csv.reader(f, dialect=csv.unix_dialect)) - # header + 1 data - assert len(rows) == 2 - - # compare headers - assert rows[0] == list(TABLE_UPDATE_COLUMNS_SCHEMA.keys()) - # compare row - assert len(rows[1]) == len(list(TABLE_ROW_ALL_DATA_TYPES.values())) - # compare none values - for actual, expected, col_name in zip( - rows[1], TABLE_ROW_ALL_DATA_TYPES.values(), TABLE_UPDATE_COLUMNS_SCHEMA.keys() - ): - if expected is None: - assert actual == "", f"{col_name} is not recovered as None" - - # write again with several arrows - with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: - writer.write_data_item([table, table], TABLE_UPDATE_COLUMNS_SCHEMA) - writer.write_data_item(table.to_batches(), TABLE_UPDATE_COLUMNS_SCHEMA) - - with open(writer.closed_files[0].file_path, "r", encoding="utf-8") as f: - rows = list(csv.reader(f, dialect=csv.unix_dialect)) - # header + 3 data - assert len(rows) == 4 - assert rows[0] == list(TABLE_UPDATE_COLUMNS_SCHEMA.keys()) - assert rows[1] == rows[2] == rows[3] - - # simulate non announced schema change - base_table = remove_columns(table, ["col9_null"]) - base_column_schema = copy(TABLE_UPDATE_COLUMNS_SCHEMA) - base_column_schema.pop("col9_null") - - with pytest.raises(InvalidDataItem): - with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: - writer.write_data_item([base_table, table], TABLE_UPDATE_COLUMNS_SCHEMA) - - with pytest.raises(InvalidDataItem): - with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: - writer.write_data_item(table, TABLE_UPDATE_COLUMNS_SCHEMA) - writer.write_data_item(base_table, TABLE_UPDATE_COLUMNS_SCHEMA) - - # schema change will rotate the file - with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: - writer.write_data_item(base_table, base_column_schema) - writer.write_data_item([table, table], TABLE_UPDATE_COLUMNS_SCHEMA) - - assert len(writer.closed_files) == 2 - - # first file - with open(writer.closed_files[0].file_path, "r", encoding="utf-8") as f: - rows = list(csv.reader(f, dialect=csv.unix_dialect)) - # header + 1 data - assert len(rows) == 2 - assert rows[0] == list(base_column_schema.keys()) - # second file - with open(writer.closed_files[1].file_path, "r", encoding="utf-8") as f: - rows = list(csv.reader(f, dialect=csv.unix_dialect)) - # header + 2 data - assert len(rows) == 3 - assert rows[0] == list(TABLE_UPDATE_COLUMNS_SCHEMA.keys()) - - -def test_non_utf8_binary() -> None: - data = copy(TABLE_ROW_ALL_DATA_TYPES_DATETIMES) - data["col7"] += b"\x8e" # type: ignore[operator] - - # write parquet and read it - with get_writer(ParquetDataWriter) as pq_writer: - pq_writer.write_data_item([data], TABLE_UPDATE_COLUMNS_SCHEMA) - - with open(pq_writer.closed_files[0].file_path, "rb") as f: - table = pq.read_table(f) - - with pytest.raises(InvalidDataItem) as inv_ex: - with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: - writer.write_data_item(table, TABLE_UPDATE_COLUMNS_SCHEMA) - assert "Arrow data contains string or binary columns" in str(inv_ex.value) - - -def test_arrow_struct() -> None: - item, _, _ = arrow_table_all_data_types("table", include_json=True, include_time=False) - with pytest.raises(InvalidDataItem): - with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: - writer.write_data_item(item, TABLE_UPDATE_COLUMNS_SCHEMA) - - -def test_csv_writer_empty() -> None: - with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: - writer.write_empty_file(TABLE_UPDATE_COLUMNS_SCHEMA) - - with open(writer.closed_files[0].file_path, "r", encoding="utf-8") as f: - rows = list(csv.reader(f, dialect=csv.unix_dialect)) - # only header - assert len(rows) == 1 - - assert rows[0] == list(TABLE_UPDATE_COLUMNS_SCHEMA.keys()) diff --git a/tests/libs/test_csv_writer.py b/tests/libs/test_csv_writer.py new file mode 100644 index 0000000000..f588c5bceb --- /dev/null +++ b/tests/libs/test_csv_writer.py @@ -0,0 +1,221 @@ +import csv +from copy import copy +from typing import Any, Dict, Type +import pytest +import pyarrow.csv as acsv +import pyarrow.parquet as pq + +from dlt.common import json +from dlt.common.data_writers.exceptions import InvalidDataItem +from dlt.common.data_writers.writers import ( + ArrowToCsvWriter, + CsvWriter, + DataWriter, + ParquetDataWriter, +) +from dlt.common.libs.pyarrow import remove_columns + +from tests.common.data_writers.utils import get_writer +from tests.cases import ( + TABLE_UPDATE_COLUMNS_SCHEMA, + TABLE_ROW_ALL_DATA_TYPES_DATETIMES, + TABLE_ROW_ALL_DATA_TYPES, + arrow_table_all_data_types, +) +from tests.utils import TestDataItemFormat + + +def test_csv_arrow_writer_all_data_fields() -> None: + data = copy(TABLE_ROW_ALL_DATA_TYPES_DATETIMES) + + # write parquet and read it + with get_writer(ParquetDataWriter) as pq_writer: + pq_writer.write_data_item([data], TABLE_UPDATE_COLUMNS_SCHEMA) + + with open(pq_writer.closed_files[0].file_path, "rb") as f: + table = pq.read_table(f) + + with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: + writer.write_data_item(table, TABLE_UPDATE_COLUMNS_SCHEMA) + + with open(writer.closed_files[0].file_path, "r", encoding="utf-8", newline="") as f: + rows = list(csv.reader(f, dialect=csv.unix_dialect)) + # header + 1 data + assert len(rows) == 2 + + # compare headers + assert rows[0] == list(TABLE_UPDATE_COLUMNS_SCHEMA.keys()) + # compare row + assert len(rows[1]) == len(list(TABLE_ROW_ALL_DATA_TYPES.values())) + + # TODO: uncomment and fix decimal256 and None that is "" + # with open(writer.closed_files[0].file_path, "br") as f: + # csv_table = acsv.read_csv(f, convert_options=acsv.ConvertOptions(column_types=table.schema)) + # for actual, expected in zip(table.to_pylist(), csv_table.to_pylist()): + # assert actual == expected + + # write again with several arrows + with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: + writer.write_data_item([table, table], TABLE_UPDATE_COLUMNS_SCHEMA) + writer.write_data_item(table.to_batches(), TABLE_UPDATE_COLUMNS_SCHEMA) + + with open(writer.closed_files[0].file_path, "r", encoding="utf-8") as f: + rows = list(csv.reader(f, dialect=csv.unix_dialect)) + # header + 3 data + assert len(rows) == 4 + assert rows[0] == list(TABLE_UPDATE_COLUMNS_SCHEMA.keys()) + assert rows[1] == rows[2] == rows[3] + + # simulate non announced schema change + base_table = remove_columns(table, ["col9_null"]) + base_column_schema = copy(TABLE_UPDATE_COLUMNS_SCHEMA) + base_column_schema.pop("col9_null") + + with pytest.raises(InvalidDataItem): + with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: + writer.write_data_item([base_table, table], TABLE_UPDATE_COLUMNS_SCHEMA) + + with pytest.raises(InvalidDataItem): + with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: + writer.write_data_item(table, TABLE_UPDATE_COLUMNS_SCHEMA) + writer.write_data_item(base_table, TABLE_UPDATE_COLUMNS_SCHEMA) + + # schema change will rotate the file + with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: + writer.write_data_item(base_table, base_column_schema) + writer.write_data_item([table, table], TABLE_UPDATE_COLUMNS_SCHEMA) + + assert len(writer.closed_files) == 2 + + # first file + with open(writer.closed_files[0].file_path, "r", encoding="utf-8") as f: + rows = list(csv.reader(f, dialect=csv.unix_dialect)) + # header + 1 data + assert len(rows) == 2 + assert rows[0] == list(base_column_schema.keys()) + # second file + with open(writer.closed_files[1].file_path, "r", encoding="utf-8") as f: + rows = list(csv.reader(f, dialect=csv.unix_dialect)) + # header + 2 data + assert len(rows) == 3 + assert rows[0] == list(TABLE_UPDATE_COLUMNS_SCHEMA.keys()) + + +def test_csv_object_writer_all_data_fields() -> None: + data = TABLE_ROW_ALL_DATA_TYPES_DATETIMES + + # always copy data on write (csv writer may modify the data) + with get_writer(CsvWriter, disable_compression=True) as writer: + writer.write_data_item(copy(data), TABLE_UPDATE_COLUMNS_SCHEMA) + + with open(writer.closed_files[0].file_path, "r", encoding="utf-8", newline="") as f: + rows = list(csv.reader(f, dialect=csv.unix_dialect)) + f.seek(0) + csv_rows = list(csv.DictReader(f, dialect=csv.unix_dialect)) + # header + 1 data + assert len(rows) == 2 + + # compare headers + assert rows[0] == list(TABLE_UPDATE_COLUMNS_SCHEMA.keys()) + # compare row + assert len(rows[1]) == len(list(TABLE_ROW_ALL_DATA_TYPES.values())) + assert_csv_rows(csv_rows[0], TABLE_ROW_ALL_DATA_TYPES_DATETIMES) + + # write again with several tables + with get_writer(CsvWriter, disable_compression=True) as writer: + writer.write_data_item([copy(data), copy(data)], TABLE_UPDATE_COLUMNS_SCHEMA) + writer.write_data_item(copy(data), TABLE_UPDATE_COLUMNS_SCHEMA) + + with open(writer.closed_files[0].file_path, "r", encoding="utf-8") as f: + rows = list(csv.reader(f, dialect=csv.unix_dialect)) + # header + 3 data + assert len(rows) == 4 + assert rows[0] == list(TABLE_UPDATE_COLUMNS_SCHEMA.keys()) + assert rows[1] == rows[2] == rows[3] + + base_data = copy(data) + base_data.pop("col9_null") + base_column_schema = copy(TABLE_UPDATE_COLUMNS_SCHEMA) + base_column_schema.pop("col9_null") + + # schema change will rotate the file + with get_writer(CsvWriter, disable_compression=True) as writer: + writer.write_data_item(copy(base_data), base_column_schema) + writer.write_data_item([copy(data), copy(data)], TABLE_UPDATE_COLUMNS_SCHEMA) + + assert len(writer.closed_files) == 2 + + # first file + with open(writer.closed_files[0].file_path, "r", encoding="utf-8") as f: + rows = list(csv.reader(f, dialect=csv.unix_dialect)) + # header + 1 data + assert len(rows) == 2 + assert rows[0] == list(base_column_schema.keys()) + # second file + with open(writer.closed_files[1].file_path, "r", encoding="utf-8") as f: + rows = list(csv.reader(f, dialect=csv.unix_dialect)) + # header + 2 data + assert len(rows) == 3 + assert rows[0] == list(TABLE_UPDATE_COLUMNS_SCHEMA.keys()) + + # simulate non announced schema change + # ignored by reader: we'd need to check this per row which will slow it down + with get_writer(CsvWriter, disable_compression=True) as writer: + writer.write_data_item([copy(base_data), copy(data)], base_column_schema) + + +@pytest.mark.parametrize("item_type", ["object", "arrow-table"]) +def test_non_utf8_binary(item_type: TestDataItemFormat) -> None: + data = copy(TABLE_ROW_ALL_DATA_TYPES_DATETIMES) + data["col7"] += b"\x8e" # type: ignore[operator] + + if item_type == "arrow-table": + # write parquet and read it + with get_writer(ParquetDataWriter) as pq_writer: + pq_writer.write_data_item([data], TABLE_UPDATE_COLUMNS_SCHEMA) + + with open(pq_writer.closed_files[0].file_path, "rb") as f: + table = pq.read_table(f) + else: + table = data + writer_type: Type[DataWriter] = ArrowToCsvWriter if item_type == "arrow-table" else CsvWriter # type: ignore + + with pytest.raises(InvalidDataItem) as inv_ex: + with get_writer(writer_type, disable_compression=True) as writer: + writer.write_data_item(table, TABLE_UPDATE_COLUMNS_SCHEMA) + assert "Remove binary columns" in str(inv_ex.value) + + +def test_arrow_struct() -> None: + item, _, _ = arrow_table_all_data_types("arrow-table", include_json=True, include_time=False) + with pytest.raises(InvalidDataItem): + with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: + writer.write_data_item(item, TABLE_UPDATE_COLUMNS_SCHEMA) + + +@pytest.mark.parametrize("item_type", ["object", "arrow-table"]) +def test_csv_writer_empty(item_type: TestDataItemFormat) -> None: + writer_type: Type[DataWriter] = ArrowToCsvWriter if item_type == "arrow-table" else CsvWriter # type: ignore + with get_writer(writer_type, disable_compression=True) as writer: + writer.write_empty_file(TABLE_UPDATE_COLUMNS_SCHEMA) + + with open(writer.closed_files[0].file_path, "r", encoding="utf-8") as f: + rows = list(csv.reader(f, dialect=csv.unix_dialect)) + # only header + assert len(rows) == 1 + + assert rows[0] == list(TABLE_UPDATE_COLUMNS_SCHEMA.keys()) + + +def assert_csv_rows(csv_row: Dict[str, Any], expected_row: Dict[str, Any]) -> None: + for actual, expected in zip(csv_row.items(), expected_row.values()): + if expected is None: + expected = "" + elif isinstance(expected, dict): + expected = json.dumps(expected) + else: + # writer calls `str` on non string + expected = expected.decode("utf-8") if isinstance(expected, bytes) else str(expected) + assert actual[1] == expected, print( + f"Failed on {actual[0]}: actual: {actual[1]} vs expected: {expected}" + ) diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py index df9fc80ad0..82ccb24bf1 100644 --- a/tests/load/pipeline/test_arrow_loading.py +++ b/tests/load/pipeline/test_arrow_loading.py @@ -8,15 +8,14 @@ import pandas as pd import dlt -from dlt.common import Decimal from dlt.common import pendulum from dlt.common.time import reduce_pendulum_datetime_precision from dlt.common.utils import uniq_id from tests.load.utils import destinations_configs, DestinationTestConfiguration -from tests.load.pipeline.utils import assert_table, assert_query_data, select_data +from tests.load.pipeline.utils import select_data from tests.pipeline.utils import assert_load_info -from tests.utils import preserve_environ -from tests.cases import arrow_table_all_data_types, TArrowFormat +from tests.utils import TestDataItemFormat, arrow_item_from_pandas, preserve_environ, TArrowFormat +from tests.cases import arrow_table_all_data_types # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -29,9 +28,9 @@ ), ids=lambda x: x.name, ) -@pytest.mark.parametrize("item_type", ["pandas", "table", "record_batch"]) +@pytest.mark.parametrize("item_type", ["pandas", "arrow-table", "arrow-batch"]) def test_load_arrow_item( - item_type: Literal["pandas", "table", "record_batch"], + item_type: TestDataItemFormat, destination_config: DestinationTestConfiguration, ) -> None: # compression must be on for redshift @@ -147,7 +146,7 @@ def some_data(): ), ids=lambda x: x.name, ) -@pytest.mark.parametrize("item_type", ["table", "pandas", "record_batch"]) +@pytest.mark.parametrize("item_type", ["arrow-table", "pandas", "arrow-batch"]) def test_parquet_column_names_are_normalized( item_type: TArrowFormat, destination_config: DestinationTestConfiguration ) -> None: @@ -165,13 +164,7 @@ def test_parquet_column_names_are_normalized( "CreatedAt", ], ) - - if item_type == "pandas": - tbl = df - elif item_type == "table": - tbl = pa.Table.from_pandas(df) - elif item_type == "record_batch": - tbl = pa.RecordBatch.from_pandas(df) + tbl = arrow_item_from_pandas(df, item_type) @dlt.resource def some_data(): diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index a0af885484..c88be1fe07 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -1,3 +1,4 @@ +import pytest import csv import posixpath from pathlib import Path @@ -9,7 +10,7 @@ from dlt.common.schema.typing import LOADS_TABLE_NAME from tests.cases import arrow_table_all_data_types -from tests.utils import skip_if_not_active +from tests.utils import ALL_TEST_DATA_ITEM_FORMATS, TestDataItemFormat, skip_if_not_active skip_if_not_active("filesystem") @@ -99,26 +100,30 @@ def some_source(): assert len(replace_files) == 1 -def test_pipeline_csv_filesystem_destination() -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_pipeline_csv_filesystem_destination(item_type: TestDataItemFormat) -> None: os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" os.environ["RESTORE_FROM_DESTINATION"] = "False" # store locally os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "file://_storage" + pipeline = dlt.pipeline( pipeline_name="parquet_test_" + uniq_id(), destination="filesystem", dataset_name="parquet_test_" + uniq_id(), ) - item, _, _ = arrow_table_all_data_types("table", include_json=False, include_time=True) + item, rows, _ = arrow_table_all_data_types(item_type, include_json=False, include_time=True) info = pipeline.run(item, table_name="table", loader_file_format="csv") info.raise_on_failed_jobs() job = info.load_packages[0].jobs["completed_jobs"][0].file_path assert job.endswith("csv") with open(job, "r", encoding="utf-8") as f: - rows = list(csv.reader(f, dialect=csv.unix_dialect)) + csv_rows = list(csv.DictReader(f, dialect=csv.unix_dialect)) # header + 3 data rows - assert len(rows) == 4 + assert len(csv_rows) == 3 + for row, csv_row in zip(rows, csv_rows): + assert row == csv_row def test_pipeline_parquet_filesystem_destination() -> None: diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index ba36f4afab..586fdb0e39 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -907,7 +907,7 @@ def table_1(): write_disposition="merge", ) def table_2(): - yield data_to_item_format("arrow", [{"id": 2}]) + yield data_to_item_format("arrow-table", [{"id": 2}]) @dlt.resource( columns=[{"name": "id", "data_type": "bigint", "nullable": True}], @@ -1017,7 +1017,7 @@ def table_3(make_data=False): # } # for hour in range(0, max_hours) # ] -# data = data_to_item_format("arrow", data) +# data = data_to_item_format("arrow-table", data) # # print(py_arrow_to_table_schema_columns(data[0].schema)) # # print(data) # yield data diff --git a/tests/load/pipeline/test_postgres.py b/tests/load/pipeline/test_postgres.py index 50c14e9cda..1f9d24fc55 100644 --- a/tests/load/pipeline/test_postgres.py +++ b/tests/load/pipeline/test_postgres.py @@ -9,6 +9,7 @@ from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration from tests.cases import arrow_table_all_data_types, prepare_shuffled_tables from tests.pipeline.utils import assert_data_table_counts, assert_load_info, load_tables_to_dicts +from tests.utils import TestDataItemFormat @pytest.mark.parametrize( @@ -16,11 +17,22 @@ destinations_configs(default_sql_configs=True, subset=["postgres"]), ids=lambda x: x.name, ) -def test_postgres_load_csv_from_arrow(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize("item_type", ["object", "table"]) +def test_postgres_load_csv( + destination_config: DestinationTestConfiguration, item_type: TestDataItemFormat +) -> None: os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), full_refresh=True) table, shuffled_table, shuffled_removed_column = prepare_shuffled_tables() + # convert to pylist when loading from objects, this will kick the csv-reader in + if item_type == "object": + table, shuffled_table, shuffled_removed_column = ( + table.to_pylist(), + shuffled_table.to_pylist(), + shuffled_removed_column.to_pylist(), + ) + load_info = pipeline.run( [shuffled_removed_column, shuffled_table, table], table_name="table", @@ -30,6 +42,9 @@ def test_postgres_load_csv_from_arrow(destination_config: DestinationTestConfigu job = load_info.load_packages[0].jobs["completed_jobs"][0].file_path assert job.endswith("csv") assert_data_table_counts(pipeline, {"table": 5432 * 3}) + data = load_tables_to_dicts(pipeline, "table") + print(data["table"][0]["binary"].tobytes()) + # assert @pytest.mark.parametrize( @@ -37,23 +52,32 @@ def test_postgres_load_csv_from_arrow(destination_config: DestinationTestConfigu destinations_configs(default_sql_configs=True, subset=["postgres"]), ids=lambda x: x.name, ) -def test_postgres_encoded_binary(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize("item_type", ["object", "table"]) +def test_postgres_encoded_binary( + destination_config: DestinationTestConfiguration, item_type: TestDataItemFormat +) -> None: import pyarrow os.environ["RESTORE_FROM_DESTINATION"] = "False" blob = hashlib.sha3_256(random.choice(ascii_lowercase).encode()).digest() # encode as \x... which postgres understands blob_table = pyarrow.Table.from_pylist([{"hash": b"\\x" + blob.hex().encode("ascii")}]) + if item_type == "object": + blob_table = blob_table.to_pylist() + print(blob_table) + pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), full_refresh=True) load_info = pipeline.run(blob_table, table_name="table", loader_file_format="csv") assert_load_info(load_info) job = load_info.load_packages[0].jobs["completed_jobs"][0].file_path assert job.endswith("csv") + # assert if column inferred correctly + assert pipeline.default_schema.get_table_columns("table")["hash"]["data_type"] == "binary" data = load_tables_to_dicts(pipeline, "table") # print(bytes(data["table"][0]["hash"])) # data in postgres equals unencoded blob - assert bytes(data["table"][0]["hash"]) == blob + assert data["table"][0]["hash"].tobytes() == blob @pytest.mark.parametrize( @@ -65,7 +89,7 @@ def test_postgres_empty_csv_from_arrow(destination_config: DestinationTestConfig os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" os.environ["RESTORE_FROM_DESTINATION"] = "False" pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), full_refresh=True) - table, _, _ = arrow_table_all_data_types("table", include_json=False) + table, _, _ = arrow_table_all_data_types("arrow-table", include_json=False) load_info = pipeline.run( table.schema.empty_table(), table_name="table", loader_file_format="csv" diff --git a/tests/pipeline/test_arrow_sources.py b/tests/pipeline/test_arrow_sources.py index b16da73868..d9930c19ee 100644 --- a/tests/pipeline/test_arrow_sources.py +++ b/tests/pipeline/test_arrow_sources.py @@ -14,24 +14,26 @@ from dlt.pipeline.exceptions import PipelineStepFailed from tests.cases import ( - arrow_format_from_pandas, - arrow_item_from_table, arrow_table_all_data_types, prepare_shuffled_tables, +) +from tests.utils import ( + preserve_environ, TArrowFormat, + arrow_item_from_pandas, + arrow_item_from_table, ) -from tests.utils import preserve_environ @pytest.mark.parametrize( ("item_type", "is_list"), [ ("pandas", False), - ("table", False), - ("record_batch", False), + ("arrow-table", False), + ("arrow-batch", False), ("pandas", True), - ("table", True), - ("record_batch", True), + ("arrow-table", True), + ("arrow-batch", True), ], ) def test_extract_and_normalize(item_type: TArrowFormat, is_list: bool): @@ -112,11 +114,11 @@ def some_data(): ("item_type", "is_list"), [ ("pandas", False), - ("table", False), - ("record_batch", False), + ("arrow-table", False), + ("arrow-batch", False), ("pandas", True), - ("table", True), - ("record_batch", True), + ("arrow-table", True), + ("arrow-batch", True), ], ) def test_normalize_jsonl(item_type: TArrowFormat, is_list: bool): @@ -151,7 +153,7 @@ def some_data(): assert res_item == exp_item -@pytest.mark.parametrize("item_type", ["table", "record_batch"]) +@pytest.mark.parametrize("item_type", ["arrow-table", "arrow-batch"]) def test_add_map(item_type: TArrowFormat): item, _, _ = arrow_table_all_data_types(item_type, num_rows=200) @@ -173,7 +175,7 @@ def map_func(item): assert pa.compute.all(pa.compute.greater(result_tbl["int"], 80)).as_py() -@pytest.mark.parametrize("item_type", ["pandas", "table", "record_batch"]) +@pytest.mark.parametrize("item_type", ["pandas", "arrow-table", "arrow-batch"]) def test_extract_normalize_file_rotation(item_type: TArrowFormat) -> None: # do not extract state os.environ["RESTORE_FROM_DESTINATION"] = "False" @@ -205,7 +207,7 @@ def data_frames(): assert len(pipeline.get_load_package_info(load_id).jobs["new_jobs"]) == 10 -@pytest.mark.parametrize("item_type", ["pandas", "table", "record_batch"]) +@pytest.mark.parametrize("item_type", ["pandas", "arrow-table", "arrow-batch"]) def test_arrow_clashing_names(item_type: TArrowFormat) -> None: # # use parquet for dummy os.environ["DESTINATION__LOADER_FILE_FORMAT"] = "parquet" @@ -224,7 +226,7 @@ def data_frames(): assert isinstance(py_ex.value.__context__, NameNormalizationClash) -@pytest.mark.parametrize("item_type", ["table", "record_batch"]) +@pytest.mark.parametrize("item_type", ["arrow-table", "arrow-batch"]) def test_load_arrow_vary_schema(item_type: TArrowFormat) -> None: pipeline_name = "arrow_" + uniq_id() pipeline = dlt.pipeline(pipeline_name=pipeline_name, destination="duckdb") @@ -243,7 +245,7 @@ def test_load_arrow_vary_schema(item_type: TArrowFormat) -> None: pipeline.run(item, table_name="data").raise_on_failed_jobs() -@pytest.mark.parametrize("item_type", ["pandas", "table", "record_batch"]) +@pytest.mark.parametrize("item_type", ["pandas", "arrow-table", "arrow-batch"]) def test_arrow_as_data_loading(item_type: TArrowFormat) -> None: os.environ["RESTORE_FROM_DESTINATION"] = "False" os.environ["DESTINATION__LOADER_FILE_FORMAT"] = "parquet" @@ -261,7 +263,7 @@ def test_arrow_as_data_loading(item_type: TArrowFormat) -> None: assert info.row_counts["items"] == len(rows) -@pytest.mark.parametrize("item_type", ["table"]) # , "pandas", "record_batch" +@pytest.mark.parametrize("item_type", ["arrow-table"]) # , "pandas", "arrow-batch" def test_normalize_with_dlt_columns(item_type: TArrowFormat): item, records, _ = arrow_table_all_data_types(item_type, num_rows=5432) os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_LOAD_ID"] = "True" @@ -327,7 +329,7 @@ def some_data(): assert schema.tables["some_data"]["columns"]["static_int"]["data_type"] == "bigint" -@pytest.mark.parametrize("item_type", ["table", "pandas", "record_batch"]) +@pytest.mark.parametrize("item_type", ["arrow-table", "pandas", "arrow-batch"]) def test_normalize_reorder_columns_separate_packages(item_type: TArrowFormat) -> None: os.environ["RESTORE_FROM_DESTINATION"] = "False" table, shuffled_table, shuffled_removed_column = prepare_shuffled_tables() @@ -378,7 +380,7 @@ def _to_item(table: Any) -> Any: load_info.raise_on_failed_jobs() -@pytest.mark.parametrize("item_type", ["table", "pandas", "record_batch"]) +@pytest.mark.parametrize("item_type", ["arrow-table", "pandas", "arrow-batch"]) def test_normalize_reorder_columns_single_package(item_type: TArrowFormat) -> None: os.environ["RESTORE_FROM_DESTINATION"] = "False" # we do not want to rotate buffer @@ -420,7 +422,7 @@ def _to_item(table: Any) -> Any: pipeline.load().raise_on_failed_jobs() -@pytest.mark.parametrize("item_type", ["table", "pandas", "record_batch"]) +@pytest.mark.parametrize("item_type", ["arrow-table", "pandas", "arrow-batch"]) def test_normalize_reorder_columns_single_batch(item_type: TArrowFormat) -> None: os.environ["RESTORE_FROM_DESTINATION"] = "False" # we do not want to rotate buffer @@ -472,7 +474,7 @@ def _to_item(table: Any) -> Any: pipeline.load().raise_on_failed_jobs() -@pytest.mark.parametrize("item_type", ["pandas", "table", "record_batch"]) +@pytest.mark.parametrize("item_type", ["pandas", "arrow-table", "arrow-batch"]) def test_empty_arrow(item_type: TArrowFormat) -> None: os.environ["RESTORE_FROM_DESTINATION"] = "False" os.environ["DESTINATION__LOADER_FILE_FORMAT"] = "parquet" @@ -495,7 +497,7 @@ def test_empty_arrow(item_type: TArrowFormat) -> None: empty_df = pd.DataFrame(columns=item.columns) item_resource = dlt.resource( - arrow_format_from_pandas(empty_df, item_type), name="items", write_disposition="replace" + arrow_item_from_pandas(empty_df, item_type), name="items", write_disposition="replace" ) info = pipeline.extract(item_resource) load_id = info.loads_ids[0] diff --git a/tests/pipeline/test_schema_contracts.py b/tests/pipeline/test_schema_contracts.py index 28a5f03fb1..2dba9d7f6d 100644 --- a/tests/pipeline/test_schema_contracts.py +++ b/tests/pipeline/test_schema_contracts.py @@ -100,7 +100,7 @@ def run_resource( pipeline: Pipeline, resource_fun: Callable[..., DltResource], settings: Any, - item_format: TestDataItemFormat = "json", + item_format: TestDataItemFormat = "object", duplicates: int = 1, ) -> None: for item in settings.keys(): @@ -116,7 +116,7 @@ def run_resource( def source() -> Iterator[DltResource]: for idx in range(duplicates): resource: DltResource = resource_fun(settings.get("resource")) - if item_format != "json": + if item_format != "object": resource._pipe.replace_gen(data_to_item_format(item_format, resource._pipe.gen())) # type: ignore resource.table_name = resource.name yield resource.with_name(resource.name + str(idx)) @@ -181,7 +181,7 @@ def test_new_tables( pipeline.drop_pending_packages() # NOTE: arrow / pandas do not support variants and subtables so we must skip - if item_format == "json": + if item_format == "object": # run add variant column run_resource(pipeline, items_with_variant, full_settings) table_counts = load_table_counts( @@ -243,7 +243,7 @@ def test_new_columns( assert table_counts["items"] == expected_items_count # NOTE: arrow / pandas do not support variants and subtables so we must skip - if item_format == "json": + if item_format == "object": # subtable should work run_resource(pipeline, items_with_subtable, full_settings) table_counts = load_table_counts( diff --git a/tests/pipeline/utils.py b/tests/pipeline/utils.py index 94683e4995..8f736e13d9 100644 --- a/tests/pipeline/utils.py +++ b/tests/pipeline/utils.py @@ -171,9 +171,9 @@ def load_tables_to_dicts(p: dlt.Pipeline, *table_names: str) -> Dict[str, List[D for table_name in table_names: table_rows = [] columns = p.default_schema.get_table_columns(table_name).keys() - query_columns = ",".join(columns) with p.sql_client() as c: + query_columns = ",".join(map(c.escape_column_name, columns)) f_q_table_name = c.make_qualified_table_name(table_name) query = f"SELECT {query_columns} FROM {f_q_table_name}" with c.execute_query(query) as cur: diff --git a/tests/utils.py b/tests/utils.py index 73e99c3fcd..ffebec7dc5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -76,11 +76,12 @@ for destination in ACTIVE_DESTINATIONS: assert destination in IMPLEMENTED_DESTINATIONS, f"Unknown active destination {destination}" +TArrowFormat = Literal["pandas", "arrow-table", "arrow-batch"] +"""Possible arrow item formats""" -# possible TDataItem types -TestDataItemFormat = Literal["json", "pandas", "arrow", "arrow-batch"] +TestDataItemFormat = Literal["object", "pandas", "arrow-table", "arrow-batch"] ALL_TEST_DATA_ITEM_FORMATS = get_args(TestDataItemFormat) -"""List with TDataItem formats: json, arrow table/batch / pandas""" +"""List with TDataItem formats: object, arrow table/batch / pandas""" def TEST_DICT_CONFIG_PROVIDER(): @@ -188,7 +189,7 @@ def data_to_item_format( item_format: TestDataItemFormat, data: Union[Iterator[TDataItem], Iterable[TDataItem]] ) -> Any: """Return the given data in the form of pandas, arrow table/batch or json items""" - if item_format == "json": + if item_format == "object": return data import pandas as pd @@ -198,7 +199,7 @@ def data_to_item_format( df = pd.DataFrame(list(data)) if item_format == "pandas": return [df] - elif item_format == "arrow": + elif item_format == "arrow-table": return [pa.Table.from_pandas(df)] elif item_format == "arrow-batch": return [pa.RecordBatch.from_pandas(df)] @@ -225,6 +226,34 @@ def data_item_length(data: TDataItem) -> int: raise TypeError("Unsupported data type.") +def arrow_item_from_pandas( + df: Any, + object_format: TArrowFormat, +) -> Any: + from dlt.common.libs.pyarrow import pyarrow as pa + + if object_format == "pandas": + return df + elif object_format == "arrow-table": + return pa.Table.from_pandas(df) + elif object_format == "arrow-batch": + return pa.RecordBatch.from_pandas(df) + raise ValueError("Unknown item type: " + object_format) + + +def arrow_item_from_table( + table: Any, + object_format: TArrowFormat, +) -> Any: + if object_format == "pandas": + return table.to_pandas() + elif object_format == "arrow-table": + return table + elif object_format == "arrow-batch": + return table.to_batches()[0] + raise ValueError("Unknown item type: " + object_format) + + def init_test_logging(c: RunConfiguration = None) -> None: if not c: c = resolve_configuration(RunConfiguration())