Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for precision, scale in column schema #646

Merged
merged 20 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dlt/common/data_writers/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None:
self.schema = pyarrow.schema(
[pyarrow.field(
name,
get_py_arrow_datatype(schema_item["data_type"], self._caps, self.timestamp_timezone),
get_py_arrow_datatype(schema_item, self._caps, self.timestamp_timezone),
nullable=schema_item["nullable"]
) for name, schema_item in columns_schema.items()]
)
Expand Down
30 changes: 23 additions & 7 deletions dlt/common/libs/pyarrow.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Any, Tuple
from typing import Any, Tuple, Optional
from dlt import version
from dlt.common.exceptions import MissingDependencyException

from dlt.common.destination.capabilities import DestinationCapabilitiesContext
from dlt.common.schema.typing import TColumnType

try:
import pyarrow
Expand All @@ -11,30 +12,33 @@
raise MissingDependencyException("DLT parquet Helpers", [f"{version.DLT_PKG_NAME}[parquet]"], "DLT Helpers for for parquet.")


def get_py_arrow_datatype(column_type: str, caps: DestinationCapabilitiesContext, tz: str) -> Any:
def get_py_arrow_datatype(column: TColumnType, caps: DestinationCapabilitiesContext, tz: str) -> Any:
column_type = column["data_type"]
if column_type == "text":
return pyarrow.string()
elif column_type == "double":
return pyarrow.float64()
elif column_type == "bool":
return pyarrow.bool_()
elif column_type == "timestamp":
return get_py_arrow_timestamp(caps.timestamp_precision, tz)
return get_py_arrow_timestamp(column.get("precision") or caps.timestamp_precision, tz)
elif column_type == "bigint":
return pyarrow.int64()
return get_pyarrow_int(column.get("precision"))
elif column_type == "binary":
return pyarrow.binary()
return pyarrow.binary(column.get("precision") or -1)
elif column_type == "complex":
# return pyarrow.struct([pyarrow.field('json', pyarrow.string())])
return pyarrow.string()
elif column_type == "decimal":
return get_py_arrow_numeric(caps.decimal_precision)
precision, scale = column.get("precision"), column.get("scale")
precision_tuple = (precision, scale) if precision is not None and scale is not None else caps.decimal_precision
return get_py_arrow_numeric(precision_tuple)
elif column_type == "wei":
return get_py_arrow_numeric(caps.wei_precision)
elif column_type == "date":
return pyarrow.date32()
elif column_type == "time":
return get_py_arrow_time(caps.timestamp_precision)
return get_py_arrow_time(column.get("precision") or caps.timestamp_precision)
else:
raise ValueError(column_type)

Expand Down Expand Up @@ -66,3 +70,15 @@ def get_py_arrow_numeric(precision: Tuple[int, int]) -> Any:
return pyarrow.decimal256(*precision)
# for higher precision use max precision and trim scale to leave the most significant part
return pyarrow.decimal256(76, max(0, 76 - (precision[0] - precision[1])))


def get_pyarrow_int(precision: Optional[int]) -> Any:
    """Return the narrowest pyarrow signed integer type that fits `precision` bits.

    Falls back to ``pyarrow.int64()`` when no precision is given or when it
    exceeds 32 bits.
    """
    if precision is not None:
        # Probe the candidate widths from narrowest to widest.
        for max_bits, factory in ((8, pyarrow.int8), (16, pyarrow.int16), (32, pyarrow.int32)):
            if precision <= max_bits:
                return factory()
    return pyarrow.int64()
9 changes: 7 additions & 2 deletions dlt/common/schema/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,15 @@
WRITE_DISPOSITIONS: Set[TWriteDisposition] = set(get_args(TWriteDisposition))


class TColumnSchemaBase(TypedDict, total=False):
class TColumnType(TypedDict, total=False):
# TypedDict describing just the data-type part of a column schema.
# All keys are optional (total=False).
# logical dlt data type of the column (e.g. "bigint", "decimal")
data_type: Optional[TDataType]
# type precision (digits for decimals, bits for integers, fractional
# second digits for timestamps — depends on data_type)
precision: Optional[int]
# decimal scale; only meaningful together with precision
scale: Optional[int]


class TColumnSchemaBase(TColumnType, total=False):
"""TypedDict that defines basic properties of a column: name, data type and nullable"""
# column name as stored in the schema
name: Optional[str]
# NOTE(review): data_type is already declared on the TColumnType base; this
# repeated declaration looks like diff-view residue — confirm against the repo.
data_type: Optional[TDataType]
# whether the column accepts NULL values
nullable: Optional[bool]


Expand Down
5 changes: 5 additions & 0 deletions dlt/common/schema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,11 @@ def table_schema_has_type(table: TTableSchema, _typ: TDataType) -> bool:
return any(c.get("data_type") == _typ for c in table["columns"].values())


def table_schema_has_type_with_precision(table: TTableSchema, _typ: TDataType) -> bool:
    """Checks if `table` schema contains a column of type `_typ` that also has precision set."""
    for column in table["columns"].values():
        if column.get("data_type") == _typ and column.get("precision") is not None:
            return True
    return False


def get_top_level_table(tables: TSchemaTables, table_name: str) -> TTableSchema:
"""Finds top level (without parent) of a `table_name` following the ancestry hierarchy."""
table = tables[table_name]
Expand Down
8 changes: 5 additions & 3 deletions dlt/common/time.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import contextlib
from typing import Any, Optional, Union, overload # noqa
from typing import Any, Optional, Union, overload, TypeVar # noqa
import datetime # noqa: I251

from dlt.common.pendulum import pendulum, timedelta
Expand Down Expand Up @@ -148,5 +148,7 @@ def to_seconds(td: Optional[TimedeltaSeconds]) -> Optional[float]:
return td


def reduce_pendulum_datetime_precision(value: pendulum.DateTime, microsecond_precision: int) -> pendulum.DateTime:
return value.set(microsecond=value.microsecond // 10**(6 - microsecond_precision) * 10**(6 - microsecond_precision)) # type: ignore
T = TypeVar("T", bound=Union[pendulum.DateTime, pendulum.Time])

def reduce_pendulum_datetime_precision(value: T, microsecond_precision: int) -> T:
    """Truncate the microsecond field of `value` to `microsecond_precision` decimal digits."""
    # `step` is the unit of the least-significant digit kept; floor-dividing and
    # multiplying back zeroes out everything finer than that digit.
    step = 10 ** (6 - microsecond_precision)
    return value.replace(microsecond=(value.microsecond // step) * step)  # type: ignore
8 changes: 8 additions & 0 deletions dlt/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
T = TypeVar("T")
TDict = TypeVar("TDict", bound=DictStrAny)

TKey = TypeVar("TKey")
TValue = TypeVar("TValue")

# row counts
TRowCount = Dict[str, int]

Expand Down Expand Up @@ -457,3 +460,8 @@ def maybe_context(manager: ContextManager[TAny]) -> Iterator[TAny]:
else:
with manager as ctx:
yield ctx


def without_none(d: Mapping[TKey, Optional[TValue]]) -> Mapping[TKey, TValue]:
    """Return a new dict containing only the entries of `d` whose value is not `None`."""
    return dict((key, value) for key, value in d.items() if value is not None)
97 changes: 57 additions & 40 deletions dlt/destinations/athena/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from pyathena.formatter import DefaultParameterFormatter, _DEFAULT_FORMATTERS, Formatter, _format_date

from dlt.common import logger
from dlt.common.utils import without_none
from dlt.common.data_types import TDataType
from dlt.common.schema import TColumnSchema, Schema
from dlt.common.schema.typing import TTableSchema
from dlt.common.schema.typing import TTableSchema, TColumnType
from dlt.common.schema.utils import table_schema_has_type
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import LoadJob
Expand All @@ -31,32 +32,59 @@
from dlt.destinations.typing import DBApiCursor
from dlt.destinations.job_client_impl import SqlJobClientBase, StorageSchemaInfo
from dlt.destinations.athena.configuration import AthenaClientConfiguration
from dlt.destinations.type_mapping import TypeMapper
from dlt.destinations import path_utils

SCT_TO_HIVET: Dict[TDataType, str] = {
"complex": "string",
"text": "string",
"double": "double",
"bool": "boolean",
"date": "date",
"timestamp": "timestamp",
"bigint": "bigint",
"binary": "binary",
"decimal": "decimal(%i,%i)",
"time": "string"
}

HIVET_TO_SCT: Dict[str, TDataType] = {
"varchar": "text",
"double": "double",
"boolean": "bool",
"date": "date",
"timestamp": "timestamp",
"bigint": "bigint",
"binary": "binary",
"varbinary": "binary",
"decimal": "decimal",
}

class AthenaTypeMapper(TypeMapper):
# Maps dlt column types to Athena/Hive DDL types that take no parameters.
sct_to_unbound_dbt = {
"complex": "string",
"text": "string",
"double": "double",
"bool": "boolean",
"date": "date",
"timestamp": "timestamp",
"bigint": "bigint",
"binary": "binary",
"time": "string"
}

# Parameterized DDL types; the %i placeholders receive (precision, scale).
sct_to_dbt = {
"decimal": "decimal(%i,%i)",
"wei": "decimal(%i,%i)"
}

# Reverse mapping used when reflecting an existing Athena table back into a
# dlt schema; every Hive integer width collapses to dlt "bigint".
dbt_to_sct = {
"varchar": "text",
"double": "double",
"boolean": "bool",
"date": "date",
"timestamp": "timestamp",
"bigint": "bigint",
"binary": "binary",
"varbinary": "binary",
"decimal": "decimal",
"tinyint": "bigint",
"smallint": "bigint",
"int": "bigint",
}

def to_db_integer_type(self, precision: Optional[int]) -> str:
# Pick the narrowest Hive integer type for the requested precision
# (presumably bits, mirroring get_pyarrow_int — TODO confirm);
# missing or >32 precision falls back to bigint.
if precision is None:
return "bigint"
if precision <= 8:
return "tinyint"
elif precision <= 16:
return "smallint"
elif precision <= 32:
return "int"
return "bigint"

def from_db_type(self, db_type: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
# Match by prefix so parameterized forms such as "decimal(38,9)" still
# hit the bare "decimal" key.
for key, val in self.dbt_to_sct.items():
if db_type.startswith(key):
# Strip None precision/scale so they do not appear as explicit keys.
return without_none(dict(data_type=val, precision=precision, scale=scale)) # type: ignore[return-value]
# Unknown db type: no dlt data type could be resolved.
return dict(data_type=None)


# add a formatter for pendulum to be used by pyathen dbapi
Expand Down Expand Up @@ -265,28 +293,17 @@ def __init__(self, schema: Schema, config: AthenaClientConfiguration) -> None:
super().__init__(schema, config, sql_client)
self.sql_client: AthenaSQLClient = sql_client # type: ignore
self.config: AthenaClientConfiguration = config
self.type_mapper = AthenaTypeMapper(self.capabilities)

def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None:
# never truncate tables in athena: the requested `truncate_tables` list is
# deliberately ignored and an empty list is passed to the base class.
super().initialize_storage([])

@classmethod
def _to_db_type(cls, sc_t: TDataType) -> str:
if sc_t == "wei":
return SCT_TO_HIVET["decimal"] % cls.capabilities.wei_precision
if sc_t == "decimal":
return SCT_TO_HIVET["decimal"] % cls.capabilities.decimal_precision
return SCT_TO_HIVET[sc_t]

@classmethod
def _from_db_type(cls, hive_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType:
for key, val in HIVET_TO_SCT.items():
if hive_t.startswith(key):
return val
return None
def _from_db_type(self, hive_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
# Delegate Hive-type -> dlt-column-type resolution to the shared type mapper.
return self.type_mapper.from_db_type(hive_t, precision, scale)

def _get_column_def_sql(self, c: TColumnSchema) -> str:
return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self._to_db_type(c['data_type'])}"
return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c)}"

def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]:

Expand Down
Loading