Add some missing tests (#896)
* add pydantic contracts implementation tests

* add tests for removal of normalizer section in schema

* add tests for contracts on nested dicts

* start working on pyarrow tests

* start adding tests of pyarrow normalizer

* add pyarrow normalizer tests

* add basic arrow tests

* merge fixes

* update tests

---------

Co-authored-by: Marcin Rudolf <[email protected]>
sh-rp and rudolfix authored Apr 24, 2024
1 parent 6432ed7 commit edcedd5
Showing 6 changed files with 610 additions and 76 deletions.
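For context, a minimal sketch of how the new test modules might be invoked locally (assuming a dlt development environment with pytest and pyarrow installed; the repository's own make/tox targets may differ). The module paths are taken from the file list below:

# Hypothetical local invocation of the new and changed test modules.
import pytest

if __name__ == "__main__":
    pytest.main(
        [
            "tests/libs/pyarrow/test_pyarrow.py",
            "tests/libs/pyarrow/test_pyarrow_normalizer.py",
            "tests/normalize/test_normalize.py",
            "-q",
        ]
    )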
Empty file added tests/libs/pyarrow/__init__.py
191 changes: 191 additions & 0 deletions tests/libs/pyarrow/test_pyarrow.py
@@ -0,0 +1,191 @@
from datetime import timezone, datetime, timedelta # noqa: I251
from copy import deepcopy
from typing import List, Any

import pytest
import pyarrow as pa

from dlt.common import pendulum
from dlt.common.libs.pyarrow import (
    py_arrow_to_table_schema_columns,
    from_arrow_scalar,
    get_py_arrow_timestamp,
    to_arrow_scalar,
    get_py_arrow_datatype,
    remove_null_columns,
    remove_columns,
    append_column,
    rename_columns,
    is_arrow_item,
)
from dlt.common.destination import DestinationCapabilitiesContext
from tests.cases import TABLE_UPDATE_COLUMNS_SCHEMA


def test_py_arrow_to_table_schema_columns():
    dlt_schema = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA)

    caps = DestinationCapabilitiesContext.generic_capabilities()
    # The arrow schema will add precision
    dlt_schema["col4"]["precision"] = caps.timestamp_precision
    dlt_schema["col6"]["precision"], dlt_schema["col6"]["scale"] = caps.decimal_precision
    dlt_schema["col11"]["precision"] = caps.timestamp_precision
    dlt_schema["col4_null"]["precision"] = caps.timestamp_precision
    dlt_schema["col6_null"]["precision"], dlt_schema["col6_null"]["scale"] = caps.decimal_precision
    dlt_schema["col11_null"]["precision"] = caps.timestamp_precision

    # Ignoring wei as we can't distinguish from decimal
    dlt_schema["col8"]["precision"], dlt_schema["col8"]["scale"] = (76, 0)
    dlt_schema["col8"]["data_type"] = "decimal"
    dlt_schema["col8_null"]["precision"], dlt_schema["col8_null"]["scale"] = (76, 0)
    dlt_schema["col8_null"]["data_type"] = "decimal"
    # No json type
    dlt_schema["col9"]["data_type"] = "text"
    del dlt_schema["col9"]["variant"]
    dlt_schema["col9_null"]["data_type"] = "text"
    del dlt_schema["col9_null"]["variant"]

    # arrow string fields don't have precision
    del dlt_schema["col5_precision"]["precision"]

    # Convert to arrow schema
    arrow_schema = pa.schema(
        [
            pa.field(
                column["name"],
                get_py_arrow_datatype(column, caps, "UTC"),
                nullable=column["nullable"],
            )
            for column in dlt_schema.values()
        ]
    )

    result = py_arrow_to_table_schema_columns(arrow_schema)

    # Resulting schema should match the original
    assert result == dlt_schema


def test_to_arrow_scalar() -> None:
    naive_dt = get_py_arrow_timestamp(6, tz=None)
    # print(naive_dt)
    # naive datetimes are converted as UTC when time aware python objects are used
    assert to_arrow_scalar(datetime(2021, 1, 1, 5, 2, 32), naive_dt).as_py() == datetime(
        2021, 1, 1, 5, 2, 32
    )
    assert to_arrow_scalar(
        datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone.utc), naive_dt
    ).as_py() == datetime(2021, 1, 1, 5, 2, 32)
    assert to_arrow_scalar(
        datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), naive_dt
    ).as_py() == datetime(2021, 1, 1, 5, 2, 32) + timedelta(hours=8)

    # naive datetimes are treated like UTC
    utc_dt = get_py_arrow_timestamp(6, tz="UTC")
    dt_converted = to_arrow_scalar(
        datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), utc_dt
    ).as_py()
    assert dt_converted.utcoffset().seconds == 0
    assert dt_converted == datetime(2021, 1, 1, 13, 2, 32, tzinfo=timezone.utc)

    berlin_dt = get_py_arrow_timestamp(6, tz="Europe/Berlin")
    dt_converted = to_arrow_scalar(
        datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), berlin_dt
    ).as_py()
    # no dst
    assert dt_converted.utcoffset().seconds == 60 * 60
    assert dt_converted == datetime(2021, 1, 1, 13, 2, 32, tzinfo=timezone.utc)


def test_from_arrow_scalar() -> None:
    naive_dt = get_py_arrow_timestamp(6, tz=None)
    sc_dt = to_arrow_scalar(datetime(2021, 1, 1, 5, 2, 32), naive_dt)

    # this value is like UTC
    py_dt = from_arrow_scalar(sc_dt)
    assert isinstance(py_dt, pendulum.DateTime)
    # and we convert to explicit UTC
    assert py_dt == datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone.utc)

    # converts to UTC
    berlin_dt = get_py_arrow_timestamp(6, tz="Europe/Berlin")
    sc_dt = to_arrow_scalar(
        datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), berlin_dt
    )
    py_dt = from_arrow_scalar(sc_dt)
    assert isinstance(py_dt, pendulum.DateTime)
    assert py_dt.tzname() == "UTC"
    assert py_dt == datetime(2021, 1, 1, 13, 2, 32, tzinfo=timezone.utc)


def _row_at_index(table: pa.Table, index: int) -> List[Any]:
    return [table.column(column_name)[index].as_py() for column_name in table.column_names]


@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch])
def test_remove_null_columns(pa_type: Any) -> None:
    table = pa_type.from_pylist(
        [
            {"a": 1, "b": 2, "c": None},
            {"a": 1, "b": None, "c": None},
        ]
    )
    result = remove_null_columns(table)
    assert result.column_names == ["a", "b"]
    assert _row_at_index(result, 0) == [1, 2]
    assert _row_at_index(result, 1) == [1, None]


@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch])
def test_remove_columns(pa_type: Any) -> None:
    table = pa_type.from_pylist(
        [
            {"a": 1, "b": 2, "c": 5},
            {"a": 1, "b": 3, "c": 4},
        ]
    )
    result = remove_columns(table, ["b"])
    assert result.column_names == ["a", "c"]
    assert _row_at_index(result, 0) == [1, 5]
    assert _row_at_index(result, 1) == [1, 4]


@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch])
def test_append_column(pa_type: Any) -> None:
    table = pa_type.from_pylist(
        [
            {"a": 1, "b": 2},
            {"a": 1, "b": 3},
        ]
    )
    result = append_column(table, "c", pa.array([5, 6]))
    assert result.column_names == ["a", "b", "c"]
    assert _row_at_index(result, 0) == [1, 2, 5]
    assert _row_at_index(result, 1) == [1, 3, 6]


@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch])
def test_rename_column(pa_type: Any) -> None:
    table = pa_type.from_pylist(
        [
            {"a": 1, "b": 2, "c": 5},
            {"a": 1, "b": 3, "c": 4},
        ]
    )
    result = rename_columns(table, ["one", "two", "three"])
    assert result.column_names == ["one", "two", "three"]
    assert _row_at_index(result, 0) == [1, 2, 5]
    assert _row_at_index(result, 1) == [1, 3, 4]


@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch])
def test_is_arrow_item(pa_type: Any) -> None:
    table = pa_type.from_pylist(
        [
            {"a": 1, "b": 2, "c": 5},
            {"a": 1, "b": 3, "c": 4},
        ]
    )
    assert is_arrow_item(table)
    assert not is_arrow_item(table.to_pydict())
    assert not is_arrow_item("hello")
111 changes: 111 additions & 0 deletions tests/libs/pyarrow/test_pyarrow_normalizer.py
@@ -0,0 +1,111 @@
from typing import List, Any

import pyarrow as pa
import pytest

from dlt.common.libs.pyarrow import normalize_py_arrow_item, NameNormalizationClash
from dlt.common.normalizers import explicit_normalizers, import_normalizers
from dlt.common.schema.utils import new_column, TColumnSchema
from dlt.common.destination import DestinationCapabilitiesContext


def _normalize(table: pa.Table, columns: List[TColumnSchema]) -> pa.Table:
    _, naming, _ = import_normalizers(explicit_normalizers())
    caps = DestinationCapabilitiesContext()
    columns_schema = {c["name"]: c for c in columns}
    return normalize_py_arrow_item(table, columns_schema, naming, caps)


def _row_at_index(table: pa.Table, index: int) -> List[Any]:
    return [table.column(column_name)[index].as_py() for column_name in table.column_names]


def test_quick_return_if_nothing_to_do() -> None:
    table = pa.Table.from_pylist(
        [
            {"a": 1, "b": 2},
        ]
    )
    columns = [new_column("a", "bigint"), new_column("b", "bigint")]
    result = _normalize(table, columns)
    # same object returned
    assert result == table


def test_pyarrow_reorder_columns() -> None:
    table = pa.Table.from_pylist(
        [
            {"col_new": "hello", "col1": 1, "col2": "a"},
        ]
    )
    columns = [new_column("col2", "text"), new_column("col1", "bigint")]
    result = _normalize(table, columns)
    # new columns appear at the end
    assert result.column_names == ["col2", "col1", "col_new"]
    assert _row_at_index(result, 0) == ["a", 1, "hello"]


def test_pyarrow_add_empty_types() -> None:
    table = pa.Table.from_pylist(
        [
            {"col1": 1},
        ]
    )
    columns = [new_column("col1", "bigint"), new_column("col2", "text")]
    result = _normalize(table, columns)
    # new columns appear at the end
    assert result.column_names == ["col1", "col2"]
    assert _row_at_index(result, 0) == [1, None]
    assert result.schema.field(1).type == "string"


def test_field_normalization_clash() -> None:
    table = pa.Table.from_pylist(
        [
            {"col^New": "hello", "col_new": 1},
        ]
    )
    with pytest.raises(NameNormalizationClash):
        _normalize(table, [])


def test_field_normalization() -> None:
    table = pa.Table.from_pylist(
        [
            {"col^New": "hello", "col2": 1},
        ]
    )
    result = _normalize(table, [])
    assert result.column_names == ["col_new", "col2"]
    assert _row_at_index(result, 0) == ["hello", 1]


def test_default_dlt_columns_not_added() -> None:
    table = pa.Table.from_pylist(
        [
            {"col1": 1},
        ]
    )
    columns = [
        new_column("_dlt_something", "bigint"),
        new_column("_dlt_id", "text"),
        new_column("_dlt_load_id", "text"),
        new_column("col2", "text"),
        new_column("col1", "text"),
    ]
    result = _normalize(table, columns)
    # no _dlt_id or _dlt_load_id columns
    assert result.column_names == ["_dlt_something", "col2", "col1"]
    assert _row_at_index(result, 0) == [None, None, 1]


@pytest.mark.skip(reason="Somehow this does not fail, should we add an exception??")
def test_fails_if_adding_non_nullable_column() -> None:
    table = pa.Table.from_pylist(
        [
            {"col1": 1},
        ]
    )
    columns = [new_column("col1", "bigint"), new_column("col2", "text", nullable=False)]
    with pytest.raises(Exception):
        _normalize(table, columns)
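For readers unfamiliar with the naming step, the clash and normalization tests above rely on dlt's snake_case naming convention mapping "col^New" to "col_new". A rough, illustrative approximation of that mapping (not dlt's actual normalizer, which handles many more cases) could look like this:

# Rough approximation of snake_case column-name normalization, for illustration only.
import re


def rough_snake_case(name: str) -> str:
    name = re.sub(r"[^a-zA-Z0-9]+", "_", name)  # replace non-alphanumerics with "_"
    name = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", name)  # break camelCase boundaries
    return re.sub(r"_+", "_", name).strip("_").lower()


assert rough_snake_case("col^New") == "col_new"
assert rough_snake_case("col2") == "col2"
# a clash: two different source names map to the same normalized name
assert rough_snake_case("col^New") == rough_snake_case("col_new")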
42 changes: 40 additions & 2 deletions tests/normalize/test_normalize.py
@@ -6,6 +6,7 @@
from dlt.common import json
from dlt.common.destination.capabilities import TLoaderFileFormat
from dlt.common.schema.schema import Schema
from dlt.common.schema.utils import new_table
from dlt.common.storages.exceptions import SchemaNotFoundError
from dlt.common.typing import StrAny
from dlt.common.data_types import TDataType
@@ -601,15 +602,15 @@ def extract_and_normalize_cases(normalize: Normalize, cases: Sequence[str]) -> s
    return normalize_pending(normalize)


def normalize_pending(normalize: Normalize) -> str:
def normalize_pending(normalize: Normalize, schema: Schema = None) -> str:
    # pool not required for map_single
    load_ids = normalize.normalize_storage.extracted_packages.list_packages()
    assert len(load_ids) == 1, "Only one package allowed or rewrite tests"
    for load_id in load_ids:
        normalize._step_info_start_load_id(load_id)
        normalize.load_storage.new_packages.create_package(load_id)
        # read schema from package
        schema = normalize.normalize_storage.extracted_packages.load_schema(load_id)
        schema = schema or normalize.normalize_storage.extracted_packages.load_schema(load_id)
        # get files
        schema_files = normalize.normalize_storage.extracted_packages.list_new_jobs(load_id)
        # normalize without pool
@@ -708,3 +709,40 @@ def assert_timestamp_data_type(load_storage: LoadStorage, data_type: TDataType)
    event_schema = load_storage.normalized_packages.load_schema(loads[0])
    # in raw normalize timestamp column must not be coerced to timestamp
    assert event_schema.get_table_columns("event")["timestamp"]["data_type"] == data_type


def test_removal_of_normalizer_schema_section_and_add_seen_data(raw_normalize: Normalize) -> None:
    extract_cases(
        raw_normalize,
        [
            "event.event.user_load_1",
        ],
    )
    load_ids = raw_normalize.normalize_storage.extracted_packages.list_packages()
    assert len(load_ids) == 1
    extracted_schema = raw_normalize.normalize_storage.extracted_packages.load_schema(load_ids[0])

    # add some normalizer blocks
    extracted_schema.tables["event"] = new_table("event")
    extracted_schema.tables["event__parse_data__intent_ranking"] = new_table(
        "event__parse_data__intent_ranking"
    )
    extracted_schema.tables["event__random_table"] = new_table("event__random_table")

    # add x-normalizer info (and another block as a control)
    extracted_schema.tables["event"]["x-normalizer"] = {"evolve-columns-once": True}  # type: ignore
    extracted_schema.tables["event"]["x-other-info"] = "blah"  # type: ignore
    extracted_schema.tables["event__parse_data__intent_ranking"]["x-normalizer"] = {"seen-data": True, "random-entry": 1234}  # type: ignore
    extracted_schema.tables["event__random_table"]["x-normalizer"] = {"evolve-columns-once": True}  # type: ignore

    normalize_pending(raw_normalize, extracted_schema)
    schema = raw_normalize.schema_storage.load_schema("event")
    # seen-data gets added, schema settings get removed
    assert schema.tables["event"]["x-normalizer"] == {"seen-data": True}  # type: ignore
    assert schema.tables["event__parse_data__intent_ranking"]["x-normalizer"] == {  # type: ignore
        "seen-data": True,
        "random-entry": 1234,
    }
    # no data seen here, so seen-data is not set and the evolve setting stays until first data is seen
    assert schema.tables["event__random_table"]["x-normalizer"] == {"evolve-columns-once": True}  # type: ignore
    assert "x-other-info" in schema.tables["event"]