Commit
* add pydantic contracts implementation tests
* add tests for removal of normalizer section in schema
* add tests for contracts on nested dicts
* start working on pyarrow tests
* start adding tests of pyarrow normalizer
* add pyarrow normalizer tests
* add basic arrow tests
* merge fixes
* update tests

Co-authored-by: Marcin Rudolf <[email protected]>
Showing 6 changed files with 610 additions and 76 deletions.
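As orientation before the diffs: the first file below tests the dlt-to-arrow schema round-trip. A minimal sketch of that round-trip, using the same helpers the tests import (the column definitions here are hypothetical, not from the commit):

import pyarrow as pa

from dlt.common.libs.pyarrow import get_py_arrow_datatype, py_arrow_to_table_schema_columns
from dlt.common.destination import DestinationCapabilitiesContext

caps = DestinationCapabilitiesContext.generic_capabilities()
# hypothetical dlt column schemas
columns = {
    "id": {"name": "id", "data_type": "bigint", "nullable": False},
    "created_at": {"name": "created_at", "data_type": "timestamp", "nullable": True},
}
arrow_schema = pa.schema(
    [
        pa.field(c["name"], get_py_arrow_datatype(c, caps, "UTC"), nullable=c["nullable"])
        for c in columns.values()
    ]
)
# round-trips back to dlt column schemas; defaults such as timestamp precision may be added
print(py_arrow_to_table_schema_columns(arrow_schema))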
@@ -0,0 +1,191 @@
from datetime import timezone, datetime, timedelta  # noqa: I251
from copy import deepcopy
from typing import List, Any

import pytest
import pyarrow as pa

from dlt.common import pendulum
from dlt.common.libs.pyarrow import (
    py_arrow_to_table_schema_columns,
    from_arrow_scalar,
    get_py_arrow_timestamp,
    to_arrow_scalar,
    get_py_arrow_datatype,
    remove_null_columns,
    remove_columns,
    append_column,
    rename_columns,
    is_arrow_item,
)
from dlt.common.destination import DestinationCapabilitiesContext
from tests.cases import TABLE_UPDATE_COLUMNS_SCHEMA


def test_py_arrow_to_table_schema_columns():
    dlt_schema = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA)

    caps = DestinationCapabilitiesContext.generic_capabilities()
    # The arrow schema will add precision
    dlt_schema["col4"]["precision"] = caps.timestamp_precision
    dlt_schema["col6"]["precision"], dlt_schema["col6"]["scale"] = caps.decimal_precision
    dlt_schema["col11"]["precision"] = caps.timestamp_precision
    dlt_schema["col4_null"]["precision"] = caps.timestamp_precision
    dlt_schema["col6_null"]["precision"], dlt_schema["col6_null"]["scale"] = caps.decimal_precision
    dlt_schema["col11_null"]["precision"] = caps.timestamp_precision

    # Ignoring wei as we can't distinguish it from decimal
    dlt_schema["col8"]["precision"], dlt_schema["col8"]["scale"] = (76, 0)
    dlt_schema["col8"]["data_type"] = "decimal"
    dlt_schema["col8_null"]["precision"], dlt_schema["col8_null"]["scale"] = (76, 0)
    dlt_schema["col8_null"]["data_type"] = "decimal"
    # No json type
    dlt_schema["col9"]["data_type"] = "text"
    del dlt_schema["col9"]["variant"]
    dlt_schema["col9_null"]["data_type"] = "text"
    del dlt_schema["col9_null"]["variant"]

    # Arrow string fields don't have precision
    del dlt_schema["col5_precision"]["precision"]

    # Convert to arrow schema
    arrow_schema = pa.schema(
        [
            pa.field(
                column["name"],
                get_py_arrow_datatype(column, caps, "UTC"),
                nullable=column["nullable"],
            )
            for column in dlt_schema.values()
        ]
    )

    result = py_arrow_to_table_schema_columns(arrow_schema)

    # Resulting schema should match the original
    assert result == dlt_schema


def test_to_arrow_scalar() -> None:
    naive_dt = get_py_arrow_timestamp(6, tz=None)
    # values in a naive timestamp column are stored as UTC wall time:
    # aware python datetimes are converted and their tz info is dropped
    assert to_arrow_scalar(datetime(2021, 1, 1, 5, 2, 32), naive_dt).as_py() == datetime(
        2021, 1, 1, 5, 2, 32
    )
    assert to_arrow_scalar(
        datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone.utc), naive_dt
    ).as_py() == datetime(2021, 1, 1, 5, 2, 32)
    assert to_arrow_scalar(
        datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), naive_dt
    ).as_py() == datetime(2021, 1, 1, 5, 2, 32) + timedelta(hours=8)

    # aware datetimes are converted to UTC
    utc_dt = get_py_arrow_timestamp(6, tz="UTC")
    dt_converted = to_arrow_scalar(
        datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), utc_dt
    ).as_py()
    assert dt_converted.utcoffset().seconds == 0
    assert dt_converted == datetime(2021, 1, 1, 13, 2, 32, tzinfo=timezone.utc)

    berlin_dt = get_py_arrow_timestamp(6, tz="Europe/Berlin")
    dt_converted = to_arrow_scalar(
        datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), berlin_dt
    ).as_py()
    # no DST on Jan 1: Berlin is UTC+1
    assert dt_converted.utcoffset().seconds == 60 * 60
    assert dt_converted == datetime(2021, 1, 1, 13, 2, 32, tzinfo=timezone.utc)


def test_from_arrow_scalar() -> None:
    naive_dt = get_py_arrow_timestamp(6, tz=None)
    sc_dt = to_arrow_scalar(datetime(2021, 1, 1, 5, 2, 32), naive_dt)

    # the naive value is treated as UTC
    py_dt = from_arrow_scalar(sc_dt)
    assert isinstance(py_dt, pendulum.DateTime)
    # and converted to an explicit UTC datetime
    assert py_dt == datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone.utc)

    # converts to UTC
    berlin_dt = get_py_arrow_timestamp(6, tz="Europe/Berlin")
    sc_dt = to_arrow_scalar(
        datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), berlin_dt
    )
    py_dt = from_arrow_scalar(sc_dt)
    assert isinstance(py_dt, pendulum.DateTime)
    assert py_dt.tzname() == "UTC"
    assert py_dt == datetime(2021, 1, 1, 13, 2, 32, tzinfo=timezone.utc)


def _row_at_index(table: pa.Table, index: int) -> List[Any]:
    return [table.column(column_name)[index].as_py() for column_name in table.column_names]


@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch])
def test_remove_null_columns(pa_type: Any) -> None:
    table = pa_type.from_pylist(
        [
            {"a": 1, "b": 2, "c": None},
            {"a": 1, "b": None, "c": None},
        ]
    )
    result = remove_null_columns(table)
    assert result.column_names == ["a", "b"]
    assert _row_at_index(result, 0) == [1, 2]
    assert _row_at_index(result, 1) == [1, None]


@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch])
def test_remove_columns(pa_type: Any) -> None:
    table = pa_type.from_pylist(
        [
            {"a": 1, "b": 2, "c": 5},
            {"a": 1, "b": 3, "c": 4},
        ]
    )
    result = remove_columns(table, ["b"])
    assert result.column_names == ["a", "c"]
    assert _row_at_index(result, 0) == [1, 5]
    assert _row_at_index(result, 1) == [1, 4]


@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch])
def test_append_column(pa_type: Any) -> None:
    table = pa_type.from_pylist(
        [
            {"a": 1, "b": 2},
            {"a": 1, "b": 3},
        ]
    )
    result = append_column(table, "c", pa.array([5, 6]))
    assert result.column_names == ["a", "b", "c"]
    assert _row_at_index(result, 0) == [1, 2, 5]
    assert _row_at_index(result, 1) == [1, 3, 6]


@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch])
def test_rename_column(pa_type: Any) -> None:
    table = pa_type.from_pylist(
        [
            {"a": 1, "b": 2, "c": 5},
            {"a": 1, "b": 3, "c": 4},
        ]
    )
    result = rename_columns(table, ["one", "two", "three"])
    assert result.column_names == ["one", "two", "three"]
    assert _row_at_index(result, 0) == [1, 2, 5]
    assert _row_at_index(result, 1) == [1, 3, 4]


@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch])
def test_is_arrow_item(pa_type: Any) -> None:
    table = pa_type.from_pylist(
        [
            {"a": 1, "b": 2, "c": 5},
            {"a": 1, "b": 3, "c": 4},
        ]
    )
    assert is_arrow_item(table)
    assert not is_arrow_item(table.to_pydict())
    assert not is_arrow_item("hello")
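The timezone behavior asserted above also holds in plain pyarrow, without the dlt wrappers. A standalone illustration (not part of the commit):

from datetime import datetime, timedelta, timezone

import pyarrow as pa

# an aware datetime at UTC-8 ...
aware = datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8)))
# ... stored in a tz-typed timestamp is shifted to that zone's clock (here UTC)
print(pa.scalar(aware, type=pa.timestamp("us", tz="UTC")).as_py())
# -> 2021-01-01 13:02:32+00:00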
@@ -0,0 +1,111 @@
from typing import List, Any

import pyarrow as pa
import pytest

from dlt.common.libs.pyarrow import normalize_py_arrow_item, NameNormalizationClash
from dlt.common.normalizers import explicit_normalizers, import_normalizers
from dlt.common.schema.utils import new_column, TColumnSchema
from dlt.common.destination import DestinationCapabilitiesContext


def _normalize(table: pa.Table, columns: List[TColumnSchema]) -> pa.Table:
    _, naming, _ = import_normalizers(explicit_normalizers())
    caps = DestinationCapabilitiesContext()
    columns_schema = {c["name"]: c for c in columns}
    return normalize_py_arrow_item(table, columns_schema, naming, caps)


def _row_at_index(table: pa.Table, index: int) -> List[Any]:
    return [table.column(column_name)[index].as_py() for column_name in table.column_names]


def test_quick_return_if_nothing_to_do() -> None:
    table = pa.Table.from_pylist(
        [
            {"a": 1, "b": 2},
        ]
    )
    columns = [new_column("a", "bigint"), new_column("b", "bigint")]
    result = _normalize(table, columns)
    # the table is returned unchanged
    assert result == table


def test_pyarrow_reorder_columns() -> None:
    table = pa.Table.from_pylist(
        [
            {"col_new": "hello", "col1": 1, "col2": "a"},
        ]
    )
    columns = [new_column("col2", "text"), new_column("col1", "bigint")]
    result = _normalize(table, columns)
    # new columns appear at the end
    assert result.column_names == ["col2", "col1", "col_new"]
    assert _row_at_index(result, 0) == ["a", 1, "hello"]


def test_pyarrow_add_empty_types() -> None:
    table = pa.Table.from_pylist(
        [
            {"col1": 1},
        ]
    )
    columns = [new_column("col1", "bigint"), new_column("col2", "text")]
    result = _normalize(table, columns)
    # missing schema columns are added at the end with their schema type
    assert result.column_names == ["col1", "col2"]
    assert _row_at_index(result, 0) == [1, None]
    assert result.schema.field(1).type == "string"


def test_field_normalization_clash() -> None:
    table = pa.Table.from_pylist(
        [
            {"col^New": "hello", "col_new": 1},
        ]
    )
    with pytest.raises(NameNormalizationClash):
        _normalize(table, [])


def test_field_normalization() -> None:
    table = pa.Table.from_pylist(
        [
            {"col^New": "hello", "col2": 1},
        ]
    )
    result = _normalize(table, [])
    assert result.column_names == ["col_new", "col2"]
    assert _row_at_index(result, 0) == ["hello", 1]


def test_default_dlt_columns_not_added() -> None:
    table = pa.Table.from_pylist(
        [
            {"col1": 1},
        ]
    )
    columns = [
        new_column("_dlt_something", "bigint"),
        new_column("_dlt_id", "text"),
        new_column("_dlt_load_id", "text"),
        new_column("col2", "text"),
        new_column("col1", "text"),
    ]
    result = _normalize(table, columns)
    # no _dlt_id or _dlt_load_id columns are added
    assert result.column_names == ["_dlt_something", "col2", "col1"]
    assert _row_at_index(result, 0) == [None, None, 1]


@pytest.mark.skip(reason="Somehow this does not fail, should we add an exception??")
def test_fails_if_adding_non_nullable_column() -> None:
    table = pa.Table.from_pylist(
        [
            {"col1": 1},
        ]
    )
    columns = [new_column("col1", "bigint"), new_column("col2", "text", nullable=False)]
    with pytest.raises(Exception):
        _normalize(table, columns)
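Read together, these tests pin down the contract of normalize_py_arrow_item: schema columns come first in schema order, missing schema columns are padded with nulls, and columns unknown to the schema keep their data at the end. A rough toy sketch of that contract (illustrative only, not dlt's implementation; it ignores name normalization and real typing):

from typing import List

import pyarrow as pa


def toy_normalize(table: pa.Table, schema_order: List[str]) -> pa.Table:
    arrays, names = [], []
    for name in schema_order:
        if name in table.column_names:
            arrays.append(table.column(name))
        else:
            # a missing schema column is padded with nulls (string type assumed)
            arrays.append(pa.nulls(table.num_rows, pa.string()))
        names.append(name)
    for name in table.column_names:
        if name not in schema_order:
            # columns not in the schema keep their data and go to the end
            arrays.append(table.column(name))
            names.append(name)
    return pa.Table.from_arrays(arrays, names=names)


t = pa.Table.from_pylist([{"col_new": "hello", "col1": 1, "col2": "a"}])
print(toy_normalize(t, ["col2", "col1"]).column_names)  # ['col2', 'col1', 'col_new']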