diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index fdffd3dc30..37b337d942 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -218,7 +218,7 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: self.schema = pyarrow.schema( [pyarrow.field( name, - get_py_arrow_datatype(schema_item["data_type"], self._caps, self.timestamp_timezone), + get_py_arrow_datatype(schema_item, self._caps, self.timestamp_timezone), nullable=schema_item["nullable"] ) for name, schema_item in columns_schema.items()] ) diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index 151baa40fe..cb34856406 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -1,8 +1,9 @@ -from typing import Any, Tuple +from typing import Any, Tuple, Optional from dlt import version from dlt.common.exceptions import MissingDependencyException from dlt.common.destination.capabilities import DestinationCapabilitiesContext +from dlt.common.schema.typing import TColumnType try: import pyarrow @@ -11,7 +12,8 @@ raise MissingDependencyException("DLT parquet Helpers", [f"{version.DLT_PKG_NAME}[parquet]"], "DLT Helpers for for parquet.") -def get_py_arrow_datatype(column_type: str, caps: DestinationCapabilitiesContext, tz: str) -> Any: +def get_py_arrow_datatype(column: TColumnType, caps: DestinationCapabilitiesContext, tz: str) -> Any: + column_type = column["data_type"] if column_type == "text": return pyarrow.string() elif column_type == "double": @@ -19,22 +21,24 @@ def get_py_arrow_datatype(column_type: str, caps: DestinationCapabilitiesContext elif column_type == "bool": return pyarrow.bool_() elif column_type == "timestamp": - return get_py_arrow_timestamp(caps.timestamp_precision, tz) + return get_py_arrow_timestamp(column.get("precision") or caps.timestamp_precision, tz) elif column_type == "bigint": - return pyarrow.int64() + return get_pyarrow_int(column.get("precision")) elif column_type == "binary": - return pyarrow.binary() + return pyarrow.binary(column.get("precision") or -1) elif column_type == "complex": # return pyarrow.struct([pyarrow.field('json', pyarrow.string())]) return pyarrow.string() elif column_type == "decimal": - return get_py_arrow_numeric(caps.decimal_precision) + precision, scale = column.get("precision"), column.get("scale") + precision_tuple = (precision, scale) if precision is not None and scale is not None else caps.decimal_precision + return get_py_arrow_numeric(precision_tuple) elif column_type == "wei": return get_py_arrow_numeric(caps.wei_precision) elif column_type == "date": return pyarrow.date32() elif column_type == "time": - return get_py_arrow_time(caps.timestamp_precision) + return get_py_arrow_time(column.get("precision") or caps.timestamp_precision) else: raise ValueError(column_type) @@ -66,3 +70,15 @@ def get_py_arrow_numeric(precision: Tuple[int, int]) -> Any: return pyarrow.decimal256(*precision) # for higher precision use max precision and trim scale to leave the most significant part return pyarrow.decimal256(76, max(0, 76 - (precision[0] - precision[1]))) + + +def get_pyarrow_int(precision: Optional[int]) -> Any: + if precision is None: + return pyarrow.int64() + if precision <= 8: + return pyarrow.int8() + elif precision <= 16: + return pyarrow.int16() + elif precision <= 32: + return pyarrow.int32() + return pyarrow.int64() diff --git a/tests/common/test_data_writers/test_parquet_writer.py b/tests/common/test_data_writers/test_parquet_writer.py index bc94d8b8fb..54c2311ec9 100644 --- a/tests/common/test_data_writers/test_parquet_writer.py +++ b/tests/common/test_data_writers/test_parquet_writer.py @@ -90,15 +90,23 @@ def test_parquet_writer_all_data_fields() -> None: data = dict(TABLE_ROW_ALL_DATA_TYPES) # fix dates to use pendulum - data["col4"] = ensure_pendulum_datetime(data["col4"]) - data["col10"] = ensure_pendulum_date(data["col10"]) - data["col11"] = pendulum.Time.fromisoformat(data["col11"]) - data["col4_precision"] = ensure_pendulum_datetime(data["col4_precision"]) - data["col11_precision"] = pendulum.Time.fromisoformat(data["col11_precision"]) + data["col4"] = ensure_pendulum_datetime(data["col4"]) # type: ignore[arg-type] + data["col10"] = ensure_pendulum_date(data["col10"]) # type: ignore[arg-type] + data["col11"] = pendulum.Time.fromisoformat(data["col11"]) # type: ignore[arg-type] + data["col4_precision"] = ensure_pendulum_datetime(data["col4_precision"]) # type: ignore[arg-type] + data["col11_precision"] = pendulum.Time.fromisoformat(data["col11_precision"]) # type: ignore[arg-type] with get_writer("parquet") as writer: writer.write_data_item([data], TABLE_UPDATE_COLUMNS_SCHEMA) + # We want to test precision for these fields is trimmed to millisecond + data["col4_precision"] = data["col4_precision"].replace( # type: ignore[attr-defined] + microsecond=int(str(data["col4_precision"].microsecond)[:3] + "000") # type: ignore[attr-defined] + ) + data["col11_precision"] = data["col11_precision"].replace( # type: ignore[attr-defined] + microsecond=int(str(data["col11_precision"].microsecond)[:3] + "000") # type: ignore[attr-defined] + ) + with open(writer.closed_files[0], "rb") as f: table = pq.read_table(f) for key, value in data.items(): @@ -108,6 +116,14 @@ def test_parquet_writer_all_data_fields() -> None: actual = ensure_pendulum_datetime(actual) assert actual == value + assert table.schema.field("col1_precision").type == pa.int16() + # flavor=spark only writes ns precision timestamp, so this is expected + assert table.schema.field("col4_precision").type == pa.timestamp("ns") + assert table.schema.field("col5_precision").type == pa.string() + assert table.schema.field("col6_precision").type == pa.decimal128(6, 2) + assert table.schema.field("col7_precision").type == pa.binary(19) + assert table.schema.field("col11_precision").type == pa.time32("ms") + def test_parquet_writer_items_file_rotation() -> None: columns = {