Precision in parquet writer
steinitzu committed Sep 23, 2023
1 parent c0cefa0 commit 4021987
Showing 3 changed files with 45 additions and 13 deletions.
2 changes: 1 addition & 1 deletion dlt/common/data_writers/writers.py
@@ -218,7 +218,7 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None:
         self.schema = pyarrow.schema(
             [pyarrow.field(
                 name,
-                get_py_arrow_datatype(schema_item["data_type"], self._caps, self.timestamp_timezone),
+                get_py_arrow_datatype(schema_item, self._caps, self.timestamp_timezone),
                 nullable=schema_item["nullable"]
             ) for name, schema_item in columns_schema.items()]
         )
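With this change the writer passes each full column schema to get_py_arrow_datatype instead of only its data_type string, so per-column hints such as precision and scale reach the arrow type mapping. A minimal sketch of what such entries look like (column names are illustrative, not taken from the commit):

    columns_schema = {
        "order_id": {"data_type": "bigint", "precision": 16, "nullable": False},
        "price": {"data_type": "decimal", "precision": 6, "scale": 2, "nullable": True},
    }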
30 changes: 23 additions & 7 deletions dlt/common/libs/pyarrow.py
@@ -1,8 +1,9 @@
-from typing import Any, Tuple
+from typing import Any, Tuple, Optional
 from dlt import version
 from dlt.common.exceptions import MissingDependencyException

 from dlt.common.destination.capabilities import DestinationCapabilitiesContext
+from dlt.common.schema.typing import TColumnType

 try:
     import pyarrow
@@ -11,30 +12,33 @@
     raise MissingDependencyException("DLT parquet Helpers", [f"{version.DLT_PKG_NAME}[parquet]"], "DLT Helpers for parquet.")


-def get_py_arrow_datatype(column_type: str, caps: DestinationCapabilitiesContext, tz: str) -> Any:
+def get_py_arrow_datatype(column: TColumnType, caps: DestinationCapabilitiesContext, tz: str) -> Any:
+    column_type = column["data_type"]
     if column_type == "text":
         return pyarrow.string()
     elif column_type == "double":
         return pyarrow.float64()
     elif column_type == "bool":
         return pyarrow.bool_()
     elif column_type == "timestamp":
-        return get_py_arrow_timestamp(caps.timestamp_precision, tz)
+        return get_py_arrow_timestamp(column.get("precision") or caps.timestamp_precision, tz)
     elif column_type == "bigint":
-        return pyarrow.int64()
+        return get_pyarrow_int(column.get("precision"))
     elif column_type == "binary":
-        return pyarrow.binary()
+        return pyarrow.binary(column.get("precision") or -1)
     elif column_type == "complex":
         # return pyarrow.struct([pyarrow.field('json', pyarrow.string())])
         return pyarrow.string()
     elif column_type == "decimal":
-        return get_py_arrow_numeric(caps.decimal_precision)
+        precision, scale = column.get("precision"), column.get("scale")
+        precision_tuple = (precision, scale) if precision is not None and scale is not None else caps.decimal_precision
+        return get_py_arrow_numeric(precision_tuple)
     elif column_type == "wei":
         return get_py_arrow_numeric(caps.wei_precision)
     elif column_type == "date":
         return pyarrow.date32()
     elif column_type == "time":
-        return get_py_arrow_time(caps.timestamp_precision)
+        return get_py_arrow_time(column.get("precision") or caps.timestamp_precision)
     else:
         raise ValueError(column_type)

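To illustrate how the precision hints steer the resulting arrow types, here is a small sketch (not part of the commit; it assumes DestinationCapabilitiesContext.generic_capabilities() as a stand-in for a real destination's capabilities, and the expected results follow the test assertions further down):

    from dlt.common.destination.capabilities import DestinationCapabilitiesContext

    caps = DestinationCapabilitiesContext.generic_capabilities()
    get_py_arrow_datatype({"data_type": "bigint", "precision": 16}, caps, "UTC")              # int16
    get_py_arrow_datatype({"data_type": "decimal", "precision": 6, "scale": 2}, caps, "UTC")  # decimal128(6, 2)
    get_py_arrow_datatype({"data_type": "binary", "precision": 19}, caps, "UTC")              # fixed-size binary(19)
    get_py_arrow_datatype({"data_type": "decimal"}, caps, "UTC")                              # falls back to caps.decimal_precision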
@@ -66,3 +70,15 @@ def get_py_arrow_numeric(precision: Tuple[int, int]) -> Any:
         return pyarrow.decimal256(*precision)
     # for higher precision use max precision and trim scale to leave the most significant part
     return pyarrow.decimal256(76, max(0, 76 - (precision[0] - precision[1])))
+
+
+def get_pyarrow_int(precision: Optional[int]) -> Any:
+    if precision is None:
+        return pyarrow.int64()
+    if precision <= 8:
+        return pyarrow.int8()
+    elif precision <= 16:
+        return pyarrow.int16()
+    elif precision <= 32:
+        return pyarrow.int32()
+    return pyarrow.int64()
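Two details worth spelling out (illustrative comments, not part of the commit): get_pyarrow_int picks the narrowest signed integer that fits the requested bit width, and the existing decimal fallback above trims scale when precision exceeds what decimal256 can hold.

    get_pyarrow_int(None)  # int64: no hint keeps the widest type
    get_pyarrow_int(8)     # int8
    get_pyarrow_int(9)     # int16
    get_pyarrow_int(33)    # int64: anything above 32 bits falls through

    # decimal trimming, worked through for precision (80, 40):
    # decimal256(76, max(0, 76 - (80 - 40))) == decimal256(76, 36)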
26 changes: 21 additions & 5 deletions tests/common/test_data_writers/test_parquet_writer.py
@@ -90,15 +90,23 @@ def test_parquet_writer_all_data_fields() -> None:

     data = dict(TABLE_ROW_ALL_DATA_TYPES)
     # fix dates to use pendulum
-    data["col4"] = ensure_pendulum_datetime(data["col4"])
-    data["col10"] = ensure_pendulum_date(data["col10"])
-    data["col11"] = pendulum.Time.fromisoformat(data["col11"])
-    data["col4_precision"] = ensure_pendulum_datetime(data["col4_precision"])
-    data["col11_precision"] = pendulum.Time.fromisoformat(data["col11_precision"])
+    data["col4"] = ensure_pendulum_datetime(data["col4"])  # type: ignore[arg-type]
+    data["col10"] = ensure_pendulum_date(data["col10"])  # type: ignore[arg-type]
+    data["col11"] = pendulum.Time.fromisoformat(data["col11"])  # type: ignore[arg-type]
+    data["col4_precision"] = ensure_pendulum_datetime(data["col4_precision"])  # type: ignore[arg-type]
+    data["col11_precision"] = pendulum.Time.fromisoformat(data["col11_precision"])  # type: ignore[arg-type]

     with get_writer("parquet") as writer:
         writer.write_data_item([data], TABLE_UPDATE_COLUMNS_SCHEMA)

+    # We want to test precision for these fields is trimmed to millisecond
+    data["col4_precision"] = data["col4_precision"].replace(  # type: ignore[attr-defined]
+        microsecond=int(str(data["col4_precision"].microsecond)[:3] + "000")  # type: ignore[attr-defined]
+    )
+    data["col11_precision"] = data["col11_precision"].replace(  # type: ignore[attr-defined]
+        microsecond=int(str(data["col11_precision"].microsecond)[:3] + "000")  # type: ignore[attr-defined]
+    )
+
     with open(writer.closed_files[0], "rb") as f:
         table = pq.read_table(f)
         for key, value in data.items():
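The expected values are trimmed the same way the writer trims the stored ones: col4_precision and col11_precision are presumably declared with precision 3 (milliseconds) in TABLE_UPDATE_COLUMNS_SCHEMA, so microseconds past the third digit are zeroed. A quick sketch of that arithmetic, assuming a six-digit microsecond value as in the fixture:

    import pendulum

    t = pendulum.Time(11, 22, 33, 456789)
    t.replace(microsecond=int(str(t.microsecond)[:3] + "000"))  # 11:22:33.456000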
@@ -108,6 +116,14 @@ def test_parquet_writer_all_data_fields() -> None:
             actual = ensure_pendulum_datetime(actual)
         assert actual == value

+    assert table.schema.field("col1_precision").type == pa.int16()
+    # flavor=spark only writes ns precision timestamp, so this is expected
+    assert table.schema.field("col4_precision").type == pa.timestamp("ns")
+    assert table.schema.field("col5_precision").type == pa.string()
+    assert table.schema.field("col6_precision").type == pa.decimal128(6, 2)
+    assert table.schema.field("col7_precision").type == pa.binary(19)
+    assert table.schema.field("col11_precision").type == pa.time32("ms")
+

 def test_parquet_writer_items_file_rotation() -> None:
     columns = {
